/src/llama.cpp/src/llama-context.cpp
Line | Count | Source |
1 | | #include "llama-context.h" |
2 | | |
3 | | #include "llama-arch.h" |
4 | | #include "llama-impl.h" |
5 | | #include "llama-batch.h" |
6 | | #include "llama-io.h" |
7 | | #include "llama-memory.h" |
8 | | #include "llama-mmap.h" |
9 | | #include "llama-model.h" |
10 | | |
11 | | #include <cinttypes> |
12 | | #include <cmath> |
13 | | #include <cstring> |
14 | | #include <limits> |
15 | | #include <stdexcept> |
16 | | |
17 | | // |
18 | | // llama_context |
19 | | // |
20 | | |
21 | | llama_context::llama_context( |
22 | | const llama_model & model, |
23 | | llama_context_params params) : |
24 | 0 | model(model), |
25 | 0 | balloc(std::make_unique<llama_batch_allocr>(model.hparams.n_pos_per_embd())) { |
26 | | // TODO: warn when creating a llama_context with an awkward ctx size that is not a power of 2,
27 | | // may need to be backend-dependent |
28 | 0 | LLAMA_LOG_INFO("%s: constructing llama_context\n", __func__); |
29 | |
|
30 | 0 | t_start_us = model.t_start_us; |
31 | 0 | t_load_us = model.t_load_us; |
32 | |
|
33 | 0 | const auto & hparams = model.hparams; |
34 | |
|
35 | 0 | cparams.n_seq_max = std::max(1u, params.n_seq_max); |
36 | 0 | if (cparams.n_seq_max > LLAMA_MAX_SEQ) { |
37 | 0 | throw std::runtime_error("n_seq_max must be <= " + std::to_string(LLAMA_MAX_SEQ)); |
38 | 0 | } |
39 | | |
40 | 0 | cparams.n_threads = params.n_threads; |
41 | 0 | cparams.n_threads_batch = params.n_threads_batch; |
42 | 0 | cparams.yarn_ext_factor = params.yarn_ext_factor >= 0.0f ? params.yarn_ext_factor : hparams.yarn_ext_factor; |
43 | 0 | cparams.yarn_attn_factor = params.yarn_attn_factor >= 0.0f ? params.yarn_attn_factor : hparams.yarn_attn_factor; |
44 | 0 | cparams.yarn_beta_fast = params.yarn_beta_fast >= 0.0f ? params.yarn_beta_fast : hparams.yarn_beta_fast; |
45 | 0 | cparams.yarn_beta_slow = params.yarn_beta_slow >= 0.0f ? params.yarn_beta_slow : hparams.yarn_beta_slow; |
46 | 0 | cparams.embeddings = params.embeddings; |
47 | 0 | cparams.offload_kqv = params.offload_kqv; |
48 | 0 | cparams.no_perf = params.no_perf; |
49 | 0 | cparams.pooling_type = params.pooling_type; |
50 | 0 | cparams.warmup = false; |
51 | |
|
52 | 0 | cparams.n_ctx = params.n_ctx == 0 ? hparams.n_ctx_train : params.n_ctx; |
53 | 0 | cparams.rope_freq_base = params.rope_freq_base == 0.0f ? hparams.rope_freq_base_train : params.rope_freq_base; |
54 | 0 | cparams.rope_freq_scale = params.rope_freq_scale == 0.0f ? hparams.rope_freq_scale_train : params.rope_freq_scale; |
55 | |
|
56 | 0 | cparams.n_ctx_orig_yarn = params.yarn_orig_ctx != 0 ? params.yarn_orig_ctx : |
57 | 0 | hparams.n_ctx_orig_yarn != 0 ? hparams.n_ctx_orig_yarn : |
58 | 0 | hparams.n_ctx_train; |
59 | |
|
60 | 0 | cparams.cb_eval = params.cb_eval; |
61 | 0 | cparams.cb_eval_user_data = params.cb_eval_user_data; |
62 | | |
63 | | // Initialize backend samplers here so they are part of the sampling graph |
64 | | // before the reserve passes run later in this function. This avoids a later |
65 | | // re-reserve when graph nodes change. |
66 | 0 | if (params.samplers != nullptr && params.n_samplers > 0) { |
67 | 0 | for (size_t i = 0; i < params.n_samplers; ++i) { |
68 | 0 | const auto & config = params.samplers[i]; |
69 | |
|
70 | 0 | if (llama_sampler_chain_get(config.sampler, -1) == nullptr) { |
71 | 0 | throw std::runtime_error("the backend samplers must be of type llama_sampler_chain"); |
72 | 0 | } |
73 | | |
74 | 0 | if (set_sampler(config.seq_id, config.sampler)) { |
75 | 0 | const int n_samplers = llama_sampler_chain_n(config.sampler); |
76 | |
|
77 | 0 | LLAMA_LOG_INFO("%s: setting backend sampler for seq_id %d (n = %d)\n", __func__, config.seq_id, n_samplers); |
78 | 0 | } |
79 | 0 | } |
80 | 0 | } |
81 | | |
82 | 0 | auto rope_scaling_type = params.rope_scaling_type; |
83 | 0 | if (rope_scaling_type == LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED) { |
84 | 0 | rope_scaling_type = hparams.rope_scaling_type_train; |
85 | 0 | } |
86 | |
|
87 | 0 | if (rope_scaling_type == LLAMA_ROPE_SCALING_TYPE_NONE) { |
88 | 0 | cparams.rope_freq_scale = 1.0f; // never scale if scaling type is none |
89 | 0 | } |
90 | |
|
91 | 0 | if (cparams.yarn_ext_factor < 0.0f) { // negative indicates 'not set' |
92 | 0 | cparams.yarn_ext_factor = rope_scaling_type == LLAMA_ROPE_SCALING_TYPE_YARN ? 1.0f : 0.0f; |
93 | 0 | } |
94 | |
|
95 | 0 | if (cparams.yarn_ext_factor != 0) { |
96 | 0 | static auto get_mscale = [](float scale, float mscale) { |
97 | 0 | return scale <= 1.0f ? 1.0f : (0.1f * mscale * logf(scale) + 1.0f); |
98 | 0 | }; |
99 | |
|
100 | 0 | const float factor = 1.0f / cparams.rope_freq_scale; |
101 | | |
102 | | // ref: https://github.com/huggingface/transformers/blob/6d00f6b0a5679c36510f203e4226e36f517c3032/src/transformers/modeling_rope_utils.py#L336-L348 |
103 | 0 | if (hparams.rope_yarn_log_mul != 0.0f) { |
104 | | // note: here we assume `mscale == 1.0f` |
105 | | // TODO: start reading the actual value of mscale and handle the case where it is not 1.0f |
106 | 0 | float mscale = 1.0f; |
107 | 0 | const float mscale_all_dims = hparams.rope_yarn_log_mul; |
108 | | |
109 | | // [TAG_DEEPSEEK2_YARN_LOG_MUL_FIX] |
110 | | // special-case DEEPSEEK v2: |
111 | | // https://huggingface.co/deepseek-ai/DeepSeek-V2-Lite-Chat/blob/main/config.json#L42-L43 |
112 | 0 | if (model.arch == LLM_ARCH_DEEPSEEK2 && mscale_all_dims != 1.0f) { |
113 | 0 | mscale = mscale_all_dims; |
114 | 0 | } |
115 | |
|
116 | 0 | cparams.yarn_attn_factor = get_mscale(factor, mscale) / get_mscale(factor, mscale_all_dims); |
117 | |
|
118 | 0 | LLAMA_LOG_WARN("%s: setting new yarn_attn_factor = %.4f (mscale == %.1f, mscale_all_dim = %.1f)\n", |
119 | 0 | __func__, cparams.yarn_attn_factor, mscale, mscale_all_dims); |
120 | 0 | } else { |
121 | 0 | cparams.yarn_attn_factor = get_mscale(factor, 1.0f); |
122 | 0 | } |
123 | | |
124 | | // when YARN is applied with yarn_ext_factor != 0.0f, we need to cancel this factor: |
125 | | // https://github.com/ggml-org/llama.cpp/blob/a81a569577cc38b32558958b048228150be63eae/ggml/src/ggml-cpu/ops.cpp#L5541-L5544 |
126 | | // |
127 | | // ref: https://github.com/ggml-org/llama.cpp/discussions/7416 |
128 | | // https://github.com/ggml-org/llama.cpp/pull/17945 |
129 | 0 | cparams.yarn_attn_factor *= 1.0f / (1.0f + 0.1f * logf(factor)); |
130 | 0 | } |
131 | |
|
132 | 0 | cparams.yarn_attn_factor *= hparams.rope_attn_factor; |
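
    A minimal standalone sketch of the yarn_attn_factor arithmetic above, assuming rope_freq_scale = 0.25
    (a 4x context extension), rope_yarn_log_mul == 0 and rope_attn_factor = 1.0; all names below are local
    to the sketch rather than llama.cpp symbols.

        #include <cmath>
        #include <cstdio>

        // mirrors the get_mscale lambda above
        static float get_mscale(float scale, float mscale) {
            return scale <= 1.0f ? 1.0f : (0.1f * mscale * logf(scale) + 1.0f);
        }

        int main() {
            const float rope_freq_scale  = 0.25f; // assumed 4x context extension
            const float rope_attn_factor = 1.0f;  // assumed model default

            const float factor = 1.0f / rope_freq_scale;              // 4.0
            float attn_factor  = get_mscale(factor, 1.0f);            // 1 + 0.1*ln(4) ~= 1.139
            attn_factor       *= 1.0f / (1.0f + 0.1f * logf(factor)); // cancel the factor applied inside ggml -> 1.0
            attn_factor       *= rope_attn_factor;

            printf("yarn_attn_factor = %.4f\n", attn_factor); // ~1.0 for this configuration
            return 0;
        }
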
133 | |
|
134 | 0 | if (cparams.pooling_type == LLAMA_POOLING_TYPE_UNSPECIFIED) { |
135 | 0 | if (hparams.pooling_type == LLAMA_POOLING_TYPE_UNSPECIFIED) { |
136 | 0 | cparams.pooling_type = LLAMA_POOLING_TYPE_NONE; |
137 | 0 | } else { |
138 | 0 | cparams.pooling_type = hparams.pooling_type; |
139 | 0 | } |
140 | 0 | } |
141 | |
|
142 | 0 | if (params.attention_type == LLAMA_ATTENTION_TYPE_UNSPECIFIED) { |
143 | 0 | cparams.causal_attn = hparams.causal_attn; |
144 | 0 | } else { |
145 | 0 | cparams.causal_attn = params.attention_type == LLAMA_ATTENTION_TYPE_CAUSAL; |
146 | 0 | } |
147 | |
|
148 | 0 | cparams.flash_attn = params.flash_attn_type != LLAMA_FLASH_ATTN_TYPE_DISABLED; |
149 | | |
150 | | // with causal attention, the batch size is limited by the context size |
151 | 0 | cparams.n_batch = cparams.causal_attn ? std::min(cparams.n_ctx, params.n_batch) : params.n_batch; |
152 | |
|
153 | 0 | cparams.n_ubatch = std::min(cparams.n_batch, params.n_ubatch == 0 ? params.n_batch : params.n_ubatch); |
154 | |
|
155 | 0 | cparams.op_offload = params.op_offload; |
156 | 0 | cparams.kv_unified = params.kv_unified; |
157 | |
|
158 | 0 | { |
159 | 0 | const char * LLAMA_GRAPH_REUSE_DISABLE = getenv("LLAMA_GRAPH_REUSE_DISABLE"); |
160 | 0 | graph_reuse_disable = LLAMA_GRAPH_REUSE_DISABLE ? (atoi(LLAMA_GRAPH_REUSE_DISABLE) != 0) : graph_reuse_disable; |
161 | |
|
162 | 0 | if (graph_reuse_disable) { |
163 | 0 | LLAMA_LOG_WARN("%s: graph reuse disabled\n", __func__); |
164 | 0 | } |
165 | 0 | } |
166 | | |
167 | | // ref: https://github.com/ggml-org/llama.cpp/pull/17046#discussion_r2503085732 |
168 | 0 | cparams.n_ctx = GGML_PAD(cparams.n_ctx, 256); |
169 | |
|
170 | 0 | if (cparams.kv_unified) { |
171 | 0 | cparams.n_ctx_seq = cparams.n_ctx; |
172 | 0 | } else { |
173 | 0 | cparams.n_ctx_seq = cparams.n_ctx / cparams.n_seq_max; |
174 | 0 | cparams.n_ctx_seq = GGML_PAD(cparams.n_ctx_seq, 256); |
175 | |
|
176 | 0 | if (cparams.n_ctx_seq == 0) { |
177 | 0 | throw std::runtime_error("n_ctx_seq == 0"); |
178 | 0 | } |
179 | | |
180 | 0 | if (cparams.n_ctx != cparams.n_ctx_seq * cparams.n_seq_max) { |
181 | 0 | cparams.n_ctx = cparams.n_ctx_seq * cparams.n_seq_max; |
182 | 0 | LLAMA_LOG_WARN("%s: n_ctx is not divisible by n_seq_max - adjusting to %u\n", __func__, cparams.n_ctx); |
183 | 0 | } |
184 | 0 | } |
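
    A worked example of the padding logic above, with assumed values n_ctx = 5000, n_seq_max = 4 and
    kv_unified = false; pad_to() below reproduces the round-up-to-multiple behaviour of GGML_PAD.

        #include <cstdio>

        // same rounding as GGML_PAD(x, n)
        static unsigned pad_to(unsigned x, unsigned n) {
            return ((x + n - 1) / n) * n;
        }

        int main() {
            unsigned n_ctx     = 5000; // assumed user request
            unsigned n_seq_max = 4;    // assumed

            n_ctx = pad_to(n_ctx, 256);                          // 5120
            unsigned n_ctx_seq = pad_to(n_ctx / n_seq_max, 256); // 1280
            n_ctx = n_ctx_seq * n_seq_max;                       // 5120 - already consistent, no warning

            printf("n_ctx = %u, n_ctx_seq = %u\n", n_ctx, n_ctx_seq);
            return 0;
        }
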
185 | | |
186 | 0 | LLAMA_LOG_INFO("%s: n_seq_max = %u\n", __func__, cparams.n_seq_max); |
187 | 0 | LLAMA_LOG_INFO("%s: n_ctx = %u\n", __func__, cparams.n_ctx); |
188 | 0 | LLAMA_LOG_INFO("%s: n_ctx_seq = %u\n", __func__, cparams.n_ctx_seq); |
189 | 0 | LLAMA_LOG_INFO("%s: n_batch = %u\n", __func__, cparams.n_batch); |
190 | 0 | LLAMA_LOG_INFO("%s: n_ubatch = %u\n", __func__, cparams.n_ubatch); |
191 | 0 | LLAMA_LOG_INFO("%s: causal_attn = %d\n", __func__, cparams.causal_attn); |
192 | 0 | LLAMA_LOG_INFO("%s: flash_attn = %s\n", __func__, llama_flash_attn_type_name(params.flash_attn_type)); |
193 | 0 | LLAMA_LOG_INFO("%s: kv_unified = %s\n", __func__, cparams.kv_unified ? "true" : "false"); |
194 | 0 | LLAMA_LOG_INFO("%s: freq_base = %.1f\n", __func__, cparams.rope_freq_base); |
195 | 0 | LLAMA_LOG_INFO("%s: freq_scale = %g\n", __func__, cparams.rope_freq_scale); |
196 | |
|
197 | 0 | if (cparams.n_ctx_seq < hparams.n_ctx_train) { |
198 | 0 | LLAMA_LOG_WARN("%s: n_ctx_seq (%u) < n_ctx_train (%u) -- the full capacity of the model will not be utilized\n", |
199 | 0 | __func__, cparams.n_ctx_seq, hparams.n_ctx_train); |
200 | 0 | } |
201 | |
|
202 | 0 | if (cparams.n_ctx_seq > hparams.n_ctx_train) { |
203 | 0 | LLAMA_LOG_WARN("%s: n_ctx_seq (%u) > n_ctx_train (%u) -- possible training context overflow\n", |
204 | 0 | __func__, cparams.n_ctx_seq, hparams.n_ctx_train); |
205 | 0 | } |
206 | |
|
207 | 0 | if (!hparams.vocab_only) { |
208 | | // GPU backends |
209 | 0 | for (auto * dev : model.devices) { |
210 | 0 | ggml_backend_t backend = ggml_backend_dev_init(dev, nullptr); |
211 | 0 | if (backend == nullptr) { |
212 | 0 | throw std::runtime_error(format("failed to initialize %s backend", ggml_backend_dev_name(dev))); |
213 | 0 | } |
214 | 0 | backends.emplace_back(backend); |
215 | 0 | } |
216 | | |
217 | | // add ACCEL backends (such as BLAS) |
218 | 0 | for (size_t i = 0; i < ggml_backend_dev_count(); ++i) { |
219 | 0 | ggml_backend_dev_t dev = ggml_backend_dev_get(i); |
220 | 0 | if (ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_ACCEL) { |
221 | 0 | ggml_backend_t backend = ggml_backend_dev_init(dev, nullptr); |
222 | 0 | if (backend == nullptr) { |
223 | 0 | throw std::runtime_error(format("failed to initialize %s backend", ggml_backend_dev_name(dev))); |
224 | 0 | } |
225 | 0 | backends.emplace_back(backend); |
226 | 0 | } |
227 | 0 | } |
228 | | |
229 | | // add CPU backend |
230 | 0 | backend_cpu = ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_CPU, nullptr); |
231 | 0 | if (backend_cpu == nullptr) { |
232 | 0 | throw std::runtime_error("failed to initialize CPU backend"); |
233 | 0 | } |
234 | 0 | backends.emplace_back(backend_cpu); |
235 | | |
236 | | // create a list of the set_n_threads functions in the backends |
237 | 0 | for (auto & backend : backends) { |
238 | 0 | ggml_backend_dev_t dev = ggml_backend_get_device(backend.get()); |
239 | 0 | ggml_backend_reg_t reg = dev ? ggml_backend_dev_backend_reg(dev) : nullptr; |
240 | 0 | if (reg) { |
241 | 0 | auto ggml_backend_set_n_threads_fn = (ggml_backend_set_n_threads_t) ggml_backend_reg_get_proc_address(reg, "ggml_backend_set_n_threads"); |
242 | 0 | if (ggml_backend_set_n_threads_fn) { |
243 | 0 | set_n_threads_fns.emplace_back(backend.get(), ggml_backend_set_n_threads_fn); |
244 | 0 | } |
245 | 0 | } |
246 | 0 | } |
247 | |
|
248 | 0 | llama_set_abort_callback(this, params.abort_callback, params.abort_callback_data); |
249 | | |
250 | | // graph outputs buffer |
251 | 0 | { |
252 | | // resized during inference when a batch uses more outputs |
253 | | // Create a dummy batch for initialization. |
254 | 0 | llama_batch dummy_batch = {}; |
255 | 0 | dummy_batch.n_tokens = 0; |
256 | 0 | if (output_reserve(params.n_seq_max, dummy_batch) < params.n_seq_max) { |
257 | 0 | throw std::runtime_error("failed to reserve initial output buffer"); |
258 | 0 | } |
259 | | |
260 | 0 | LLAMA_LOG_INFO("%s: %10s output buffer size = %8.2f MiB\n", __func__, |
261 | 0 | ggml_backend_buffer_name (buf_output.get()), |
262 | 0 | ggml_backend_buffer_get_size(buf_output.get()) / 1024.0 / 1024.0); |
263 | 0 | } |
264 | 0 | } |
265 | | |
266 | | // init the memory module |
267 | 0 | if (!hparams.vocab_only) { |
268 | 0 | llama_memory_params params_mem = { |
269 | 0 | /*.type_k =*/ params.type_k, |
270 | 0 | /*.type_v =*/ params.type_v, |
271 | 0 | /*.swa_full =*/ params.swa_full, |
272 | 0 | }; |
273 | |
|
274 | 0 | memory.reset(model.create_memory(params_mem, cparams)); |
275 | 0 | } |
276 | | |
277 | | // init backends |
278 | 0 | if (!hparams.vocab_only) { |
279 | 0 | LLAMA_LOG_DEBUG("%s: enumerating backends\n", __func__); |
280 | |
|
281 | 0 | backend_buft.clear(); |
282 | 0 | backend_ptrs.clear(); |
283 | 0 | backend_buf_exp_size.clear(); |
284 | |
|
285 | 0 | for (auto & backend : backends) { |
286 | 0 | auto * buft = ggml_backend_get_default_buffer_type(backend.get()); |
287 | 0 | auto backend_type = ggml_backend_dev_type(ggml_backend_get_device(backend.get())); |
288 | |
|
289 | 0 | if (backend_type == GGML_BACKEND_DEVICE_TYPE_CPU && !model.devices.empty()) { |
290 | | // use the host buffer of the first device for faster transfer of the intermediate state to/from the CPU |
291 | 0 | auto * dev = model.devices[0]; |
292 | 0 | auto * host_buft = ggml_backend_dev_host_buffer_type(dev); |
293 | 0 | if (host_buft) { |
294 | 0 | buft = host_buft; |
295 | 0 | } |
296 | 0 | } |
297 | |
|
298 | 0 | backend_buft.push_back(buft); |
299 | 0 | backend_ptrs.push_back(backend.get()); |
300 | 0 | backend_buf_exp_size.push_back(0); |
301 | 0 | } |
302 | |
|
303 | 0 | LLAMA_LOG_DEBUG("%s: backend_ptrs.size() = %zu\n", __func__, backend_ptrs.size()); |
304 | |
|
305 | 0 | const uint32_t n_seqs = cparams.n_seq_max; |
306 | 0 | const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch); |
307 | |
|
308 | 0 | const size_t max_nodes = this->graph_max_nodes(n_tokens); |
309 | |
|
310 | 0 | LLAMA_LOG_DEBUG("%s: max_nodes = %zu\n", __func__, max_nodes); |
311 | |
|
312 | 0 | gf_res_prev.reset(new llm_graph_result(max_nodes)); |
313 | 0 | gf_res_reserve.reset(new llm_graph_result(max_nodes)); |
314 | | |
315 | | // TODO: move these checks to ggml_backend_sched |
316 | | // enabling pipeline parallelism in the scheduler increases memory usage, so it is only done when necessary |
317 | 0 | bool pipeline_parallel = |
318 | 0 | model.n_devices() > 1 && |
319 | 0 | model.n_gpu_layers() > model.hparams.n_layer && |
320 | 0 | model.split_mode() == LLAMA_SPLIT_MODE_LAYER && |
321 | 0 | cparams.offload_kqv && |
322 | 0 | !model.has_tensor_overrides(); |
323 | | |
324 | | // pipeline parallelism requires support for async compute and events in all devices |
325 | 0 | if (pipeline_parallel) { |
326 | 0 | for (auto & backend : backends) { |
327 | 0 | auto dev_type = ggml_backend_dev_type(ggml_backend_get_device(backend.get())); |
328 | 0 | if (dev_type == GGML_BACKEND_DEVICE_TYPE_CPU) { |
329 | | // ignore CPU backend |
330 | 0 | continue; |
331 | 0 | } |
332 | 0 | auto * dev = ggml_backend_get_device(backend.get()); |
333 | 0 | ggml_backend_dev_props props; |
334 | 0 | ggml_backend_dev_get_props(dev, &props); |
335 | 0 | if (!props.caps.async || !props.caps.events) { |
336 | | // device does not support async compute or events |
337 | 0 | pipeline_parallel = false; |
338 | 0 | break; |
339 | 0 | } |
340 | 0 | } |
341 | 0 | } |
342 | |
|
343 | 0 | sched.reset(ggml_backend_sched_new(backend_ptrs.data(), backend_buft.data(), backend_ptrs.size(), max_nodes, pipeline_parallel, cparams.op_offload)); |
344 | |
|
345 | 0 | if (pipeline_parallel) { |
346 | 0 | LLAMA_LOG_INFO("%s: pipeline parallelism enabled (n_copies=%d)\n", __func__, ggml_backend_sched_get_n_copies(sched.get())); |
347 | 0 | } |
348 | |
|
349 | 0 | llama_memory_context_ptr mctx; |
350 | 0 | if (memory) { |
351 | 0 | LLAMA_LOG_DEBUG("%s: reserving full memory module\n", __func__); |
352 | 0 | mctx = memory->init_full(); |
353 | 0 | if (!mctx) { |
354 | 0 | throw std::runtime_error("failed to initialize memory module"); |
355 | 0 | } |
356 | 0 | } |
357 | | |
358 | 0 | cross.v_embd.clear(); |
359 | | |
360 | | // avoid reserving graphs with zero outputs - assume one output per sequence |
361 | 0 | n_outputs = n_seqs; |
362 | |
|
363 | 0 | LLAMA_LOG_DEBUG("%s: worst-case: n_tokens = %d, n_seqs = %d, n_outputs = %d\n", __func__, n_tokens, n_seqs, n_outputs); |
364 | | |
365 | | // resolve automatic Flash Attention use |
366 | 0 | if (params.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_AUTO) { |
367 | 0 | auto * gf = graph_reserve(1, n_seqs, n_outputs, mctx.get(), true); |
368 | 0 | if (!gf) { |
369 | 0 | throw std::runtime_error("failed to split graph for Flash Attention check"); |
370 | 0 | } |
371 | | |
372 | 0 | const size_t prefix_len = strlen(LLAMA_TENSOR_NAME_FATTN) + 1; |
373 | 0 | bool fa_device_mismatch = false; |
374 | 0 | for (int i = 0; i < ggml_graph_n_nodes(gf); i++) { |
375 | 0 | ggml_tensor * n = ggml_graph_node(gf, i); |
376 | 0 | if (n->op != GGML_OP_FLASH_ATTN_EXT) { |
377 | 0 | continue; |
378 | 0 | } |
379 | 0 | ggml_backend_dev_t device_fa = ggml_backend_get_device( |
380 | 0 | ggml_backend_sched_get_tensor_backend(sched.get(), n)); |
381 | | |
382 | | // TODO: instead of the tensor names, use a map to keep track of which (FA) tensors belong to which layer |
383 | 0 | GGML_ASSERT(strncmp(n->name, LLAMA_TENSOR_NAME_FATTN "-", prefix_len) == 0); |
384 | 0 | const int il = std::stoi(n->name + prefix_len); |
385 | 0 | ggml_backend_dev_t device_kv = model.dev_layer(il); |
386 | 0 | if (device_fa != device_kv) { |
387 | 0 | LLAMA_LOG_WARN("%s: layer %d is assigned to device %s but the Flash Attention tensor " |
388 | 0 | "is assigned to device %s (usually due to missing support)\n", |
389 | 0 | __func__, il, ggml_backend_dev_name(device_kv), ggml_backend_dev_name(device_fa)); |
390 | | // FIXME: the fa_device_mismatch logic is wrong for --no-kv-offload, but that path is broken anyway
391 | 0 | fa_device_mismatch = true; |
392 | 0 | break; |
393 | 0 | } |
394 | 0 | } |
395 | 0 | if (fa_device_mismatch) { |
396 | 0 | cparams.flash_attn = false; |
397 | 0 | LLAMA_LOG_WARN("%s: Flash Attention was auto, set to disabled\n", __func__); |
398 | 0 | if (ggml_is_quantized(params.type_v)) { |
399 | 0 | throw std::runtime_error("quantized V cache was requested, but this requires Flash Attention"); |
400 | 0 | } |
401 | 0 | } else { |
402 | 0 | cparams.flash_attn = true; |
403 | 0 | LLAMA_LOG_INFO("%s: Flash Attention was auto, set to enabled\n", __func__); |
404 | 0 | } |
405 | 0 | } |
406 | | |
407 | | // reserve worst-case graph |
408 | 0 | int n_splits_pp = -1; |
409 | 0 | int n_nodes_pp = -1; |
410 | |
|
411 | 0 | int n_splits_tg = -1; |
412 | 0 | int n_nodes_tg = -1; |
413 | | |
414 | | // reserve pp (prompt processing) graph first so that buffers are only allocated once |
415 | 0 | { |
416 | 0 | auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get(), |
417 | 0 | model.hparams.no_alloc, model.hparams.no_alloc ? backend_buf_exp_size.data() : nullptr); |
418 | 0 | if (!gf) { |
419 | 0 | if (pipeline_parallel) { |
420 | 0 | LLAMA_LOG_WARN("%s: compute buffer allocation failed, retrying without pipeline parallelism\n", __func__); |
421 | 0 | sched.reset(ggml_backend_sched_new(backend_ptrs.data(), backend_buft.data(), backend_ptrs.size(), max_nodes, false, cparams.op_offload)); |
422 | 0 | gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get()); |
423 | 0 | } |
424 | 0 | if (!gf) { |
425 | 0 | throw std::runtime_error("failed to allocate compute pp buffers"); |
426 | 0 | } |
427 | 0 | } |
428 | | |
429 | 0 | n_splits_pp = ggml_backend_sched_get_n_splits(sched.get()); |
430 | 0 | n_nodes_pp = ggml_graph_n_nodes(gf); |
431 | 0 | } |
432 | | |
433 | | // reserve with tg (token generation) graph to get the number of splits and nodes |
434 | 0 | { |
435 | 0 | auto * gf = graph_reserve(n_seqs, n_seqs, n_seqs, mctx.get(), model.hparams.no_alloc); |
436 | 0 | if (!gf) { |
437 | 0 | throw std::runtime_error("failed to allocate compute tg buffers"); |
438 | 0 | } |
439 | | |
440 | 0 | n_splits_tg = ggml_backend_sched_get_n_splits(sched.get()); |
441 | 0 | n_nodes_tg = ggml_graph_n_nodes(gf); |
442 | 0 | } |
443 | | |
444 | | // reserve again with pp graph to avoid ggml-alloc reallocations during inference |
445 | 0 | { |
446 | | // TODO: not sure if the following graph would be the worst case for multi-stream KV caches: |
447 | | // |
448 | | // auto * gf = graph_reserve(n_tokens, 1, n_tokens, mctx.get()); |
449 | | // |
450 | 0 | auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get(), model.hparams.no_alloc); |
451 | 0 | if (!gf) { |
452 | 0 | throw std::runtime_error("failed to allocate compute pp buffers"); |
453 | 0 | } |
454 | 0 | } |
455 | | |
456 | 0 | for (size_t i = 0; i < backend_ptrs.size(); ++i) { |
457 | 0 | ggml_backend_t backend = backend_ptrs[i]; |
458 | 0 | ggml_backend_buffer_type_t buft = backend_buft[i]; |
459 | 0 | if (!model.hparams.no_alloc) { |
460 | 0 | backend_buf_exp_size[i] = ggml_backend_sched_get_buffer_size(sched.get(), backend); |
461 | 0 | } |
462 | 0 | if (backend_buf_exp_size[i] > 1) { |
463 | 0 | LLAMA_LOG_INFO("%s: %10s compute buffer size = %8.2f MiB\n", __func__, |
464 | 0 | ggml_backend_buft_name(buft), |
465 | 0 | backend_buf_exp_size[i] / 1024.0 / 1024.0); |
466 | 0 | } |
467 | 0 | } |
468 | |
|
469 | 0 | if (n_nodes_pp == n_nodes_tg) { |
470 | 0 | LLAMA_LOG_INFO("%s: graph nodes = %d\n", __func__, n_nodes_pp); |
471 | 0 | } else { |
472 | 0 | LLAMA_LOG_INFO("%s: graph nodes = %d (with bs=%d), %d (with bs=1)\n", __func__, n_nodes_pp, n_tokens, n_nodes_tg); |
473 | 0 | } |
474 | |
|
475 | 0 | if (n_splits_pp == n_splits_tg) { |
476 | 0 | LLAMA_LOG_INFO("%s: graph splits = %d\n", __func__, n_splits_pp); |
477 | 0 | } else { |
478 | 0 | LLAMA_LOG_INFO("%s: graph splits = %d (with bs=%d), %d (with bs=1)\n", __func__, n_splits_pp, n_tokens, n_splits_tg); |
479 | 0 | } |
480 | 0 | } |
481 | | |
482 | | // Initialize the full vocabulary token ids for backend samplers. |
483 | 0 | { |
484 | 0 | const int n_vocab = model.vocab.n_tokens(); |
485 | |
|
486 | 0 | sampling.token_ids_full_vocab.resize(n_vocab); |
487 | 0 | for (int i = 0; i < n_vocab; ++i) { |
488 | 0 | sampling.token_ids_full_vocab[i] = i; |
489 | 0 | } |
490 | 0 | } |
491 | 0 | } |
492 | | |
493 | 0 | llama_context::~llama_context() { |
494 | 0 | if (!model.hparams.no_alloc) { |
495 | 0 | for (size_t i = 0; i < backend_ptrs.size(); ++i) { |
496 | 0 | ggml_backend_t backend = backend_ptrs[i]; |
497 | 0 | ggml_backend_buffer_type_t buft = backend_buft[i]; |
498 | |
|
499 | 0 | const size_t size_exp = backend_buf_exp_size[i]; |
500 | 0 | const size_t size_act = ggml_backend_sched_get_buffer_size(sched.get(), backend); |
501 | 0 | if (size_exp == size_act) { |
502 | 0 | LLAMA_LOG_DEBUG("%s: %10s compute buffer size is %8.4f MiB, matches expectation of %8.4f MiB\n", |
503 | 0 | __func__, ggml_backend_buft_name(buft), size_act / (1024.0*1024.0), size_exp / (1024.0*1024.0)); |
504 | 0 | } else { |
505 | 0 | LLAMA_LOG_WARN("%s: %10s compute buffer size of %8.4f MiB does not match expectation of %8.4f MiB\n", |
506 | 0 | __func__, ggml_backend_buft_name(buft), size_act / (1024.0*1024.0), size_exp / (1024.0*1024.0)); |
507 | 0 | } |
508 | 0 | } |
509 | 0 | } |
510 | 0 | ggml_opt_free(opt_ctx); |
511 | 0 | } |
512 | | |
513 | 0 | void llama_context::synchronize() { |
514 | 0 | ggml_backend_sched_synchronize(sched.get()); |
515 | | |
516 | | // FIXME: if multiple single tokens are evaluated without a synchronization, |
517 | | // the stats will be added to the prompt evaluation stats |
518 | | // this should only happen when using batch size 1 to evaluate a batch |
519 | | |
520 | | // add the evaluation to the stats |
521 | 0 | if (n_queued_tokens == 1) { |
522 | 0 | if (!cparams.no_perf) { |
523 | 0 | t_eval_us += ggml_time_us() - t_compute_start_us; |
524 | 0 | } |
525 | 0 | n_eval++; |
526 | 0 | } else if (n_queued_tokens > 1) { |
527 | 0 | if (!cparams.no_perf) { |
528 | 0 | t_p_eval_us += ggml_time_us() - t_compute_start_us; |
529 | 0 | } |
530 | 0 | n_p_eval += n_queued_tokens; |
531 | 0 | } |
532 | | |
533 | | // get a more accurate load time, upon first eval |
534 | 0 | if (n_queued_tokens > 0 && !has_evaluated_once) { |
535 | 0 | t_load_us = ggml_time_us() - t_start_us; |
536 | 0 | has_evaluated_once = true; |
537 | 0 | } |
538 | |
|
539 | 0 | n_queued_tokens = 0; |
540 | 0 | t_compute_start_us = 0; |
541 | 0 | } |
542 | | |
543 | 0 | const llama_model & llama_context::get_model() const { |
544 | 0 | return model; |
545 | 0 | } |
546 | | |
547 | 0 | const llama_cparams & llama_context::get_cparams() const { |
548 | 0 | return cparams; |
549 | 0 | } |
550 | | |
551 | 0 | ggml_backend_sched_t llama_context::get_sched() const { |
552 | 0 | return sched.get(); |
553 | 0 | } |
554 | | |
555 | 0 | uint32_t llama_context::n_ctx() const { |
556 | 0 | return cparams.n_ctx; |
557 | 0 | } |
558 | | |
559 | 0 | uint32_t llama_context::n_ctx_seq() const { |
560 | 0 | return cparams.n_ctx_seq; |
561 | 0 | } |
562 | | |
563 | 0 | uint32_t llama_context::n_batch() const { |
564 | 0 | return cparams.n_batch; |
565 | 0 | } |
566 | | |
567 | 0 | uint32_t llama_context::n_ubatch() const { |
568 | 0 | return cparams.n_ubatch; |
569 | 0 | } |
570 | | |
571 | 0 | uint32_t llama_context::n_seq_max() const { |
572 | 0 | return cparams.n_seq_max; |
573 | 0 | } |
574 | | |
575 | 0 | uint32_t llama_context::n_threads() const { |
576 | 0 | return cparams.n_threads; |
577 | 0 | } |
578 | | |
579 | 0 | uint32_t llama_context::n_threads_batch() const { |
580 | 0 | return cparams.n_threads_batch; |
581 | 0 | } |
582 | | |
583 | 0 | llama_memory_t llama_context::get_memory() const { |
584 | 0 | return memory.get(); |
585 | 0 | } |
586 | | |
587 | 0 | bool llama_context::memory_update(bool optimize) { |
588 | 0 | if (!memory) { |
589 | 0 | return false; |
590 | 0 | } |
591 | | |
592 | 0 | { |
593 | 0 | const auto mctx = memory->init_update(this, optimize); |
594 | 0 | switch (mctx->get_status()) { |
595 | 0 | case LLAMA_MEMORY_STATUS_SUCCESS: |
596 | 0 | { |
597 | | // noop |
598 | 0 | } break; |
599 | 0 | case LLAMA_MEMORY_STATUS_NO_UPDATE: |
600 | 0 | { |
601 | | // no updates need to be performed |
602 | 0 | return false; |
603 | 0 | } |
604 | 0 | case LLAMA_MEMORY_STATUS_FAILED_PREPARE: |
605 | 0 | case LLAMA_MEMORY_STATUS_FAILED_COMPUTE: |
606 | 0 | { |
607 | 0 | LLAMA_LOG_ERROR("%s: failed to prepare memory update\n", __func__); |
608 | 0 | return false; |
609 | 0 | } |
610 | 0 | } |
611 | | |
612 | | // reset the previous graph result to make sure that it won't be reused |
613 | | // TODO: change the mctx->apply() to return information if a graph reserve is needed |
614 | | // reset the graph result only if the memory module did reset the scheduler |
615 | 0 | gf_res_prev->reset(); |
616 | |
|
617 | 0 | if (!mctx->apply()) { |
618 | 0 | LLAMA_LOG_ERROR("%s: failed to apply memory update\n", __func__); |
619 | 0 | } |
620 | 0 | } |
621 | | |
622 | | // if the memory module did any computation, we have to reserve a new worst-case graph |
623 | 0 | { |
624 | 0 | const auto mctx = memory->init_full(); |
625 | 0 | if (!mctx) { |
626 | 0 | throw std::runtime_error("failed to initialize memory context"); |
627 | 0 | } |
628 | | |
629 | 0 | const uint32_t n_seqs = cparams.n_seq_max; |
630 | 0 | const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch); |
631 | |
|
632 | 0 | auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get()); |
633 | 0 | if (!gf) { |
634 | 0 | LLAMA_LOG_ERROR("%s: failed to reserve graph after the memory update\n", __func__); |
635 | 0 | } |
636 | 0 | } |
637 | | |
638 | 0 | return true; |
639 | 0 | } |
640 | | |
641 | 0 | enum llama_pooling_type llama_context::pooling_type() const { |
642 | 0 | return cparams.pooling_type; |
643 | 0 | } |
644 | | |
645 | 0 | float * llama_context::get_logits() { |
646 | 0 | output_reorder(); |
647 | |
|
648 | 0 | return logits; |
649 | 0 | } |
650 | | |
651 | 0 | int64_t llama_context::output_resolve_row(int32_t i) const { |
652 | 0 | int64_t j = -1; |
653 | | |
654 | | // support negative indices (last output row) |
655 | 0 | if (i < 0) { |
656 | 0 | j = n_outputs + i; |
657 | 0 | if (j < 0) { |
658 | 0 | throw std::runtime_error(format("negative index out of range [0, %d)", n_outputs)); |
659 | 0 | } |
660 | 0 | } else if ((size_t) i >= output_ids.size()) { |
661 | 0 | throw std::runtime_error(format("out of range [0, %zu)", output_ids.size())); |
662 | 0 | } else { |
663 | | // use output_ids to translate the batch token index into a row number |
664 | | // that holds this token's data. |
665 | 0 | j = output_ids[i]; |
666 | 0 | } |
667 | | |
668 | 0 | if (j < 0) { |
669 | | // the batch token was not configured to output anything |
670 | 0 | throw std::runtime_error(format("batch.logits[%d] != true", i)); |
671 | 0 | } |
672 | | |
673 | 0 | if (j >= n_outputs) { |
674 | 0 | throw std::runtime_error(format("corrupt output buffer (j=%" PRId64 ", n_outputs=%d)", j, n_outputs)); |
675 | 0 | } |
676 | | |
677 | 0 | return j; |
678 | 0 | } |
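
    A toy model of the indexing contract of output_resolve_row(): a negative i counts back from the last
    output row, while a non-negative i is a batch position translated through output_ids (-1 meaning the
    token produced no output). The inputs below are assumed for illustration and the error handling is omitted.

        #include <cstdio>
        #include <vector>

        // simplified stand-in for output_resolve_row()
        static int resolve_row(const std::vector<int> & output_ids, int n_outputs, int i) {
            if (i < 0) {
                return n_outputs + i; // -1 -> last output row
            }
            return output_ids[i];     // -1 if this batch position requested no output
        }

        int main() {
            // assumed batch of 4 tokens where only positions 1 and 3 requested output
            const std::vector<int> output_ids = { -1, 0, -1, 1 };
            const int n_outputs = 2;

            printf("batch pos 3 -> row %d\n", resolve_row(output_ids, n_outputs, 3));  // 1
            printf("idx -1      -> row %d\n", resolve_row(output_ids, n_outputs, -1)); // 1 (last output)
            printf("batch pos 0 -> row %d\n", resolve_row(output_ids, n_outputs, 0));  // -1 (no output requested)
            return 0;
        }
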
679 | | |
680 | 0 | float * llama_context::get_logits_ith(int32_t i) { |
681 | 0 | int64_t j = -1; |
682 | |
|
683 | 0 | output_reorder(); |
684 | |
|
685 | 0 | try { |
686 | 0 | if (logits == nullptr) { |
687 | 0 | throw std::runtime_error("no logits"); |
688 | 0 | } |
689 | | |
690 | | // TODO: use output_resolve_row() |
691 | 0 | if (i < 0) { |
692 | 0 | j = n_outputs + i; |
693 | 0 | if (j < 0) { |
694 | 0 | throw std::runtime_error(format("negative index out of range [0, %d)", n_outputs)); |
695 | 0 | } |
696 | 0 | } else if ((size_t) i >= output_ids.size()) { |
697 | 0 | throw std::runtime_error(format("out of range [0, %zu)", output_ids.size())); |
698 | 0 | } else { |
699 | 0 | j = output_ids[i]; |
700 | 0 | } |
701 | | |
702 | 0 | if (j < 0) { |
703 | 0 | throw std::runtime_error(format("batch.logits[%d] != true", i)); |
704 | 0 | } |
705 | 0 | if (j >= n_outputs) { |
706 | | // This should not happen |
707 | 0 | throw std::runtime_error(format("corrupt output buffer (j=%" PRId64 ", n_outputs=%d)", j, n_outputs)); |
708 | 0 | } |
709 | | |
710 | 0 | return logits + j*model.vocab.n_tokens(); |
711 | 0 | } catch (const std::exception & err) { |
712 | 0 | LLAMA_LOG_ERROR("%s: invalid logits id %d, reason: %s\n", __func__, i, err.what()); |
713 | | #ifndef NDEBUG |
714 | | GGML_ABORT("fatal error"); |
715 | | #else |
716 | 0 | return nullptr; |
717 | 0 | #endif |
718 | 0 | } |
719 | 0 | } |
720 | | |
721 | 0 | float * llama_context::get_embeddings() { |
722 | 0 | output_reorder(); |
723 | |
|
724 | 0 | return embd; |
725 | 0 | } |
726 | | |
727 | 0 | llama_token * llama_context::get_sampled_tokens() const { |
728 | 0 | return sampling.sampled; |
729 | 0 | } |
730 | | |
731 | 0 | float * llama_context::get_embeddings_ith(int32_t i) { |
732 | 0 | int64_t j = -1; |
733 | |
|
734 | 0 | output_reorder(); |
735 | |
|
736 | 0 | try { |
737 | 0 | if (embd == nullptr) { |
738 | 0 | throw std::runtime_error("no embeddings"); |
739 | 0 | } |
740 | | |
741 | | // TODO: use output_resolve_row() |
742 | 0 | if (i < 0) { |
743 | 0 | j = n_outputs + i; |
744 | 0 | if (j < 0) { |
745 | 0 | throw std::runtime_error(format("negative index out of range [0, %d)", n_outputs)); |
746 | 0 | } |
747 | 0 | } else if ((size_t) i >= output_ids.size()) { |
748 | 0 | throw std::runtime_error(format("out of range [0, %zu)", output_ids.size())); |
749 | 0 | } else { |
750 | 0 | j = output_ids[i]; |
751 | 0 | } |
752 | | |
753 | 0 | if (j < 0) { |
754 | 0 | throw std::runtime_error(format("batch.logits[%d] != true", i)); |
755 | 0 | } |
756 | 0 | if (j >= n_outputs) { |
757 | | // This should not happen |
758 | 0 | throw std::runtime_error(format("corrupt output buffer (j=%" PRId64 ", n_outputs=%d)", j, n_outputs)); |
759 | 0 | } |
760 | | |
761 | 0 | const uint32_t n_embd_out = model.hparams.get_n_embd_out(); |
762 | 0 | return embd + j*n_embd_out; |
763 | 0 | } catch (const std::exception & err) { |
764 | 0 | LLAMA_LOG_ERROR("%s: invalid embeddings id %d, reason: %s\n", __func__, i, err.what()); |
765 | | #ifndef NDEBUG |
766 | | GGML_ABORT("fatal error"); |
767 | | #else |
768 | 0 | return nullptr; |
769 | 0 | #endif |
770 | 0 | } |
771 | 0 | } |
772 | | |
773 | 0 | float * llama_context::get_embeddings_seq(llama_seq_id seq_id) { |
774 | 0 | auto it = embd_seq.find(seq_id); |
775 | 0 | if (it == embd_seq.end()) { |
776 | 0 | return nullptr; |
777 | 0 | } |
778 | | |
779 | 0 | return it->second.data(); |
780 | 0 | } |
781 | | |
782 | 0 | llama_token llama_context::get_sampled_token_ith(int32_t idx) { |
783 | 0 | output_reorder(); |
784 | |
|
785 | 0 | if (sampling.sampled == nullptr) { |
786 | 0 | return LLAMA_TOKEN_NULL; |
787 | 0 | } |
788 | | |
789 | 0 | try { |
790 | 0 | const int64_t row = output_resolve_row(idx); |
791 | 0 | GGML_ASSERT(row < (int64_t) sampling.sampled_size); |
792 | 0 | return sampling.sampled[row]; |
793 | 0 | } catch (const std::exception & err) { |
794 | 0 | LLAMA_LOG_ERROR("%s: invalid backend sampled token id %d, reason: %s\n", __func__, idx, err.what()); |
795 | 0 | return LLAMA_TOKEN_NULL; |
796 | 0 | } |
797 | 0 | } |
798 | | |
799 | 0 | float * llama_context::get_sampled_probs_ith(int32_t idx) { |
800 | 0 | output_reorder(); |
801 | |
|
802 | 0 | if (sampling.probs == nullptr) { |
803 | 0 | return nullptr; |
804 | 0 | } |
805 | | |
806 | 0 | try { |
807 | 0 | const int64_t row = output_resolve_row(idx); |
808 | 0 | if ((size_t) row >= sampling.probs_count.size() || sampling.probs_count[row] == 0) { |
809 | 0 | return nullptr; |
810 | 0 | } |
811 | 0 | return sampling.probs + row*model.vocab.n_tokens(); |
812 | 0 | } catch (const std::exception & err) { |
813 | 0 | LLAMA_LOG_ERROR("%s: invalid backend sampled probs id %d, reason: %s\n", __func__, idx, err.what()); |
814 | 0 | return nullptr; |
815 | 0 | } |
816 | 0 | } |
817 | | |
818 | 0 | float * llama_context::get_sampled_logits_ith(int32_t idx) { |
819 | 0 | output_reorder(); |
820 | |
|
821 | 0 | if (sampling.logits == nullptr) { |
822 | 0 | return nullptr; |
823 | 0 | } |
824 | | |
825 | 0 | try { |
826 | 0 | const int64_t row = output_resolve_row(idx); |
827 | 0 | if ((size_t) row >= sampling.logits_count.size() || sampling.logits_count[row] == 0) { |
828 | 0 | return nullptr; |
829 | 0 | } |
830 | 0 | return sampling.logits + row*model.vocab.n_tokens(); |
831 | 0 | } catch (const std::exception & err) { |
832 | 0 | LLAMA_LOG_ERROR("%s: invalid backend sampled logits id %d, reason: %s\n", __func__, idx, err.what()); |
833 | 0 | return nullptr; |
834 | 0 | } |
835 | 0 | } |
836 | | |
837 | 0 | const llama_token * llama_context::get_sampled_candidates_ith(int32_t idx) { |
838 | 0 | output_reorder(); |
839 | |
|
840 | 0 | try { |
841 | 0 | const int64_t row = output_resolve_row(idx); |
842 | 0 | if (sampling.candidates != nullptr && |
843 | 0 | (size_t) row < sampling.candidates_count.size() && |
844 | 0 | sampling.candidates_count[row] > 0) { |
845 | 0 | return sampling.candidates + row*model.vocab.n_tokens(); |
846 | 0 | } |
847 | 0 | } catch (const std::exception & err) { |
848 | | // fallback to full vocab list |
849 | 0 | } |
850 | | |
851 | 0 | return sampling.token_ids_full_vocab.data(); |
852 | 0 | } |
853 | | |
854 | 0 | size_t llama_context::get_sampled_candidates_count(int32_t idx) { |
855 | 0 | output_reorder(); |
856 | |
|
857 | 0 | if (sampling.candidates == nullptr) { |
858 | 0 | return 0; |
859 | 0 | } |
860 | | |
861 | 0 | try { |
862 | 0 | const int64_t row = output_resolve_row(idx); |
863 | 0 | if ((size_t) row >= sampling.candidates_count.size()) { |
864 | 0 | return 0; |
865 | 0 | } |
866 | 0 | return sampling.candidates_count[row]; |
867 | 0 | } catch (const std::exception & err) { |
868 | 0 | LLAMA_LOG_ERROR("%s: invalid backend sampled candidates count id %d, reason: %s\n", __func__, idx, err.what()); |
869 | 0 | return 0; |
870 | 0 | } |
871 | 0 | } |
872 | | |
873 | 0 | size_t llama_context::get_sampled_logits_count(int32_t idx) { |
874 | 0 | output_reorder(); |
875 | |
|
876 | 0 | if (sampling.logits == nullptr) { |
877 | 0 | return model.vocab.n_tokens(); |
878 | 0 | } |
879 | | |
880 | 0 | try { |
881 | 0 | const int64_t row = output_resolve_row(idx); |
882 | 0 | if ((size_t) row >= sampling.logits_count.size()) { |
883 | 0 | return 0; |
884 | 0 | } |
885 | 0 | return sampling.logits_count[row]; |
886 | 0 | } catch (const std::exception & err) { |
887 | 0 | LLAMA_LOG_ERROR("%s: invalid backend sampled logits count id %d, reason: %s\n", __func__, idx, err.what()); |
888 | 0 | return 0; |
889 | 0 | } |
890 | 0 | } |
891 | | |
892 | 0 | size_t llama_context::get_sampled_probs_count(int32_t idx) { |
893 | 0 | output_reorder(); |
894 | |
|
895 | 0 | if (sampling.probs == nullptr) { |
896 | 0 | return 0; |
897 | 0 | } |
898 | | |
899 | 0 | try { |
900 | 0 | const int64_t row = output_resolve_row(idx); |
901 | 0 | if ((size_t) row >= sampling.probs_count.size()) { |
902 | 0 | return 0; |
903 | 0 | } |
904 | 0 | return sampling.probs_count[row]; |
905 | 0 | } catch (const std::exception & err) { |
906 | 0 | LLAMA_LOG_ERROR("%s: invalid backend sampled probs count id %d, reason: %s\n", __func__, idx, err.what()); |
907 | 0 | return 0; |
908 | 0 | } |
909 | 0 | } |
910 | | |
911 | | |
912 | | void llama_context::attach_threadpool( |
913 | | ggml_threadpool_t threadpool, |
914 | 0 | ggml_threadpool_t threadpool_batch) { |
915 | 0 | LLAMA_LOG_DEBUG("%s: call\n", __func__); |
916 | |
|
917 | 0 | this->threadpool = threadpool; |
918 | 0 | this->threadpool_batch = threadpool_batch ? threadpool_batch : threadpool; |
919 | 0 | } |
920 | | |
921 | 0 | void llama_context::detach_threadpool() { |
922 | 0 | LLAMA_LOG_DEBUG("%s: call\n", __func__); |
923 | |
|
924 | 0 | this->threadpool = nullptr; |
925 | 0 | this->threadpool_batch = nullptr; |
926 | 0 | } |
927 | | |
928 | 0 | void llama_context::set_n_threads(int32_t n_threads, int32_t n_threads_batch) { |
929 | 0 | LLAMA_LOG_DEBUG("%s: n_threads = %d, n_threads_batch = %d\n", __func__, n_threads, n_threads_batch); |
930 | |
|
931 | 0 | cparams.n_threads = n_threads; |
932 | 0 | cparams.n_threads_batch = n_threads_batch; |
933 | 0 | } |
934 | | |
935 | 0 | void llama_context::set_abort_callback(bool (*abort_callback)(void * data), void * abort_callback_data) { |
936 | 0 | LLAMA_LOG_DEBUG("%s: call\n", __func__); |
937 | |
|
938 | 0 | this->abort_callback = abort_callback; |
939 | 0 | this->abort_callback_data = abort_callback_data; |
940 | |
|
941 | 0 | for (auto & backend : backends) { |
942 | 0 | auto * reg = ggml_backend_dev_backend_reg(ggml_backend_get_device(backend.get())); |
943 | 0 | auto * set_abort_callback_fn = (ggml_backend_set_abort_callback_t) ggml_backend_reg_get_proc_address(reg, "ggml_backend_set_abort_callback"); |
944 | 0 | if (set_abort_callback_fn) { |
945 | 0 | set_abort_callback_fn(backend.get(), this->abort_callback, this->abort_callback_data); |
946 | 0 | } |
947 | 0 | } |
948 | 0 | } |
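
    For reference, a minimal sketch of how a caller would install such a callback through the public
    llama_set_abort_callback() entry point in llama.h; the atomic flag and helper names are assumptions
    of the sketch.

        #include "llama.h"

        #include <atomic>

        static std::atomic<bool> g_should_abort{false};

        // returning true asks the backends to abort the graph computation in progress
        static bool my_abort_cb(void * /*user_data*/) {
            return g_should_abort.load();
        }

        void install_abort_handler(llama_context * ctx) {
            llama_set_abort_callback(ctx, my_abort_cb, nullptr);
        }
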
949 | | |
950 | 0 | void llama_context::set_embeddings(bool value) { |
951 | 0 | LLAMA_LOG_DEBUG("%s: value = %d\n", __func__, value); |
952 | |
|
953 | 0 | cparams.embeddings = value; |
954 | 0 | } |
955 | | |
956 | 0 | void llama_context::set_causal_attn(bool value) { |
957 | 0 | LLAMA_LOG_DEBUG("%s: value = %d\n", __func__, value); |
958 | |
|
959 | 0 | cparams.causal_attn = value; |
960 | 0 | } |
961 | | |
962 | 0 | void llama_context::set_warmup(bool value) { |
963 | 0 | LLAMA_LOG_DEBUG("%s: value = %d\n", __func__, value); |
964 | |
|
965 | 0 | cparams.warmup = value; |
966 | 0 | } |
967 | | |
968 | 0 | bool llama_context::set_sampler(llama_seq_id seq_id, llama_sampler * sampler) { |
969 | 0 | LLAMA_LOG_DEBUG("%s: seq_id = %d, sampler = %p\n", __func__, (int) seq_id, (void *) sampler); |
970 | |
|
971 | 0 | const bool can_offload = |
972 | 0 | sampler && |
973 | 0 | sampler->iface->backend_init && |
974 | 0 | sampler->iface->backend_apply && |
975 | 0 | llama_sampler_chain_n(sampler) > 0; |
976 | |
|
977 | 0 | if (sampler && can_offload) { |
978 | 0 | ggml_backend_buffer_type_t buft = ggml_backend_dev_buffer_type(model.dev_output()); |
979 | 0 | auto * host_buft = ggml_backend_dev_host_buffer_type(model.dev_output()); |
980 | 0 | if (host_buft) { |
981 | 0 | buft = host_buft; |
982 | 0 | } |
983 | |
|
984 | 0 | sampler->iface->backend_init(sampler, buft); |
985 | |
|
986 | 0 | sampling.samplers[seq_id] = sampler; |
987 | |
|
988 | 0 | return true; |
989 | 0 | } |
990 | | |
991 | 0 | if (sampler && !can_offload) { |
992 | 0 | LLAMA_LOG_WARN("%s: sampler '%s' for seq_id = %d, cannot be offloaded to the backend\n", __func__, llama_sampler_name(sampler), seq_id); |
993 | |
|
994 | 0 | sampling.samplers.erase(seq_id); |
995 | |
|
996 | 0 | return false; |
997 | 0 | } |
998 | | |
999 | 0 | sampling.samplers.erase(seq_id); |
1000 | |
|
1001 | 0 | return true; |
1002 | 0 | } |
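
    A hedged sketch of building a sampler chain of the kind set_sampler() above accepts, using the public
    llama_sampler_chain API from llama.h; whether the samplers in the chain can actually be offloaded depends
    on them implementing backend_init/backend_apply, which is an assumption here for top-k and dist.

        #include "llama.h"

        // build a llama_sampler_chain suitable for attaching to a sequence
        llama_sampler * make_seq_sampler_chain() {
            llama_sampler_chain_params sparams = llama_sampler_chain_default_params();
            llama_sampler * chain = llama_sampler_chain_init(sparams);

            llama_sampler_chain_add(chain, llama_sampler_init_top_k(40));
            llama_sampler_chain_add(chain, llama_sampler_init_dist(LLAMA_DEFAULT_SEED));

            return chain;
        }
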
1003 | | |
1004 | | void llama_context::set_adapter_lora( |
1005 | | llama_adapter_lora * adapter, |
1006 | 0 | float scale) { |
1007 | 0 | LLAMA_LOG_DEBUG("%s: adapter = %p, scale = %f\n", __func__, (void *) adapter, scale); |
1008 | |
|
1009 | 0 | loras[adapter] = scale; |
1010 | 0 | } |
1011 | | |
1012 | | bool llama_context::rm_adapter_lora( |
1013 | 0 | llama_adapter_lora * adapter) { |
1014 | 0 | LLAMA_LOG_DEBUG("%s: adapter = %p\n", __func__, (void *) adapter); |
1015 | |
|
1016 | 0 | auto pos = loras.find(adapter); |
1017 | 0 | if (pos != loras.end()) { |
1018 | 0 | loras.erase(pos); |
1019 | 0 | return true; |
1020 | 0 | } |
1021 | | |
1022 | 0 | return false; |
1023 | 0 | } |
1024 | | |
1025 | 0 | void llama_context::clear_adapter_lora() { |
1026 | 0 | LLAMA_LOG_DEBUG("%s: call\n", __func__); |
1027 | |
|
1028 | 0 | loras.clear(); |
1029 | 0 | } |
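
    A minimal usage sketch of the public adapter API that these members back, assuming the current llama.h
    names (llama_adapter_lora_init, llama_set_adapter_lora, llama_rm_adapter_lora, llama_clear_adapter_lora);
    the adapter path and the 0.75 scale are placeholder values.

        #include "llama.h"

        void apply_lora(llama_context * ctx, llama_model * model, const char * path_lora) {
            llama_adapter_lora * adapter = llama_adapter_lora_init(model, path_lora);
            if (adapter == nullptr) {
                return; // failed to load the adapter file
            }

            llama_set_adapter_lora(ctx, adapter, 0.75f); // example scale
            // ... run decode/encode calls with the adapter active ...
            llama_rm_adapter_lora(ctx, adapter);         // or llama_clear_adapter_lora(ctx) to drop all adapters
        }
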
1030 | | |
1031 | | bool llama_context::apply_adapter_cvec( |
1032 | | const float * data, |
1033 | | size_t len, |
1034 | | int32_t n_embd, |
1035 | | int32_t il_start, |
1036 | 0 | int32_t il_end) { |
1037 | 0 | LLAMA_LOG_DEBUG("%s: il_start = %d, il_end = %d\n", __func__, il_start, il_end); |
1038 | |
|
1039 | 0 | return cvec.apply(model, data, len, n_embd, il_start, il_end); |
1040 | 0 | } |
1041 | | |
1042 | 0 | llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, llm_graph_type gtype, llama_memory_context_i * mctx, ggml_status & ret) { |
1043 | 0 | if (mctx && !mctx->apply()) { |
1044 | 0 | LLAMA_LOG_ERROR("%s: failed to apply memory context\n", __func__); |
1045 | 0 | ret = GGML_STATUS_FAILED; |
1046 | 0 | return nullptr; |
1047 | 0 | } |
1048 | | |
1049 | 0 | auto * res = gf_res_prev.get(); |
1050 | 0 | auto * gf = res->get_gf(); |
1051 | | |
1052 | | // the new graph parameters |
1053 | | // in order to correctly reuse a graph, its full topology has to be uniquely determined by these parameters |
1054 | 0 | const auto gparams = graph_params(res, ubatch, mctx, gtype); |
1055 | |
|
1056 | 0 | if (!graph_reuse_disable && res->can_reuse(gparams)) { |
1057 | | //LLAMA_LOG_DEBUG("%s: reusing previous graph\n", __func__); |
1058 | |
|
1059 | 0 | n_reused++; |
1060 | 0 | } else { |
1061 | 0 | res->reset(); |
1062 | |
|
1063 | 0 | ggml_backend_sched_reset(sched.get()); |
1064 | 0 | ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data); |
1065 | | |
1066 | | //const auto t_start_us = ggml_time_us(); |
1067 | |
|
1068 | 0 | gf = model.build_graph(gparams); |
1069 | | |
1070 | | //LLAMA_LOG_INFO("graph build time: %.3f ms\n", (ggml_time_us() - t_start_us)/1000.0); |
1071 | |
|
1072 | 0 | if (!gf) { |
1073 | 0 | LLAMA_LOG_ERROR("%s: failed to initialize graph\n", __func__); |
1074 | 0 | ret = GGML_STATUS_FAILED; |
1075 | 0 | return nullptr; |
1076 | 0 | } |
1077 | | |
1078 | 0 | if (!ggml_backend_sched_alloc_graph(sched.get(), gf)) { |
1079 | 0 | LLAMA_LOG_ERROR("%s: failed to allocate graph\n", __func__); |
1080 | 0 | ret = GGML_STATUS_ALLOC_FAILED; |
1081 | 0 | return nullptr; |
1082 | 0 | } |
1083 | 0 | } |
1084 | | |
1085 | | // set the input data for the input tensors |
1086 | 0 | { |
1087 | | //const auto t_start_us = ggml_time_us(); |
1088 | |
|
1089 | 0 | res->set_inputs(&ubatch); |
1090 | | |
1091 | | //LLAMA_LOG_INFO("graph set inputs time: %.3f ms\n", (ggml_time_us() - t_start_us)/1000.0); |
1092 | 0 | } |
1093 | |
|
1094 | 0 | const auto status = graph_compute(res->get_gf(), ubatch.n_tokens > 1); |
1095 | 0 | if (status != GGML_STATUS_SUCCESS) { |
1096 | 0 | LLAMA_LOG_ERROR("%s: failed to compute graph, compute status: %d\n", __func__, status); |
1097 | 0 | ret = status; |
1098 | 0 | return nullptr; |
1099 | 0 | } |
1100 | | |
1101 | 0 | ret = GGML_STATUS_SUCCESS; |
1102 | |
|
1103 | 0 | return res; |
1104 | 0 | } |
1105 | | |
1106 | 0 | int llama_context::encode(const llama_batch & batch_inp) { |
1107 | 0 | GGML_ASSERT((!batch_inp.token && batch_inp.embd) || (batch_inp.token && !batch_inp.embd)); // NOLINT |
1108 | |
|
1109 | 0 | if (batch_inp.n_tokens == 0) { |
1110 | 0 | LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__); |
1111 | 0 | return -1; |
1112 | 0 | } |
1113 | | |
1114 | 0 | const auto & hparams = model.hparams; |
1115 | |
|
1116 | 0 | const int64_t n_embd = hparams.n_embd_inp(); |
1117 | 0 | const int64_t n_vocab = model.vocab.n_tokens(); |
1118 | | |
1119 | | // note: during encode, we always pass the full sequence starting from pos = 0 |
1120 | 0 | if (!balloc->init(batch_inp, model.vocab, nullptr, n_embd, cparams.kv_unified ? LLAMA_MAX_SEQ : cparams.n_seq_max, true)) { |
1121 | 0 | LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__); |
1122 | 0 | return -1; |
1123 | 0 | } |
1124 | | |
1125 | 0 | const uint32_t n_tokens = balloc->get_n_tokens(); |
1126 | | |
1127 | | // [TAG_NO_CACHE_PAD] |
1128 | | // TODO: add new split mode where we pad the input sequences so that ubatch.equal_seqs == true |
1129 | 0 | const llama_ubatch ubatch = balloc->split_simple(n_tokens); |
1130 | | |
1131 | | // micro-batching is not possible for non-causal encoding, so we process the batch in a single shot |
1132 | 0 | GGML_ASSERT(cparams.n_ubatch >= n_tokens && "encoder requires n_ubatch >= n_tokens"); |
1133 | |
|
1134 | 0 | if (t_compute_start_us == 0) { |
1135 | 0 | t_compute_start_us = ggml_time_us(); |
1136 | 0 | } |
1137 | | |
1138 | | // TODO: this clear of the buffer can easily be forgotten - need something better |
1139 | 0 | embd_seq.clear(); |
1140 | |
|
1141 | 0 | n_queued_tokens += n_tokens; |
1142 | | |
1143 | | // reserve output buffer |
1144 | 0 | if (output_reserve(n_tokens, batch_inp) < n_tokens) { |
1145 | 0 | LLAMA_LOG_ERROR("%s: could not reserve space for batch with %u outputs\n", __func__, n_tokens); |
1146 | 0 | return -2; |
1147 | 0 | } |
1148 | |
|
1149 | 0 | for (uint32_t i = 0; i < n_tokens; ++i) { |
1150 | 0 | output_ids[i] = i; |
1151 | 0 | } |
1152 | |
|
1153 | 0 | n_outputs = n_tokens; |
1154 | |
|
1155 | 0 | const auto causal_attn_org = cparams.causal_attn; |
1156 | | |
1157 | | // always use non-causal attention for encoder graphs |
1158 | | // TODO: this is a tmp solution until we have a proper way to support enc-dec models |
1159 | | // ref: https://github.com/ggml-org/llama.cpp/pull/12181#issuecomment-2730451223 |
1160 | 0 | cparams.causal_attn = false; |
1161 | |
|
1162 | 0 | ggml_status status; |
1163 | 0 | const auto * res = process_ubatch(ubatch, LLM_GRAPH_TYPE_ENCODER, nullptr, status); |
1164 | |
|
1165 | 0 | cparams.causal_attn = causal_attn_org; |
1166 | |
|
1167 | 0 | if (!res) { |
1168 | 0 | switch (status) { |
1169 | 0 | case GGML_STATUS_ABORTED: return 2; |
1170 | 0 | case GGML_STATUS_ALLOC_FAILED: return -2; |
1171 | 0 | case GGML_STATUS_FAILED: return -3; |
1172 | 0 | case GGML_STATUS_SUCCESS: GGML_ABORT("should not happen"); |
1173 | 0 | } |
1174 | 0 | } |
1175 | | |
1176 | 0 | auto * t_logits = res->get_logits(); |
1177 | 0 | auto * t_embd = res->get_embd_pooled() ? res->get_embd_pooled() : res->get_embd(); |
1178 | | |
1179 | | // extract logits |
1180 | 0 | if (logits && t_logits) { |
1181 | 0 | ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(sched.get(), t_logits); |
1182 | 0 | GGML_ASSERT(backend_res != nullptr); |
1183 | 0 | GGML_ASSERT(logits != nullptr); |
1184 | |
|
1185 | 0 | ggml_backend_tensor_get_async(backend_res, t_logits, logits, 0, n_tokens*n_vocab*sizeof(float)); |
1186 | 0 | } |
1187 | | |
1188 | | // extract embeddings |
1189 | 0 | if (embd && t_embd) { |
1190 | 0 | ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(sched.get(), t_embd); |
1191 | 0 | GGML_ASSERT(backend_embd != nullptr); |
1192 | |
|
1193 | 0 | switch (cparams.pooling_type) { |
1194 | 0 | case LLAMA_POOLING_TYPE_NONE: |
1195 | 0 | { |
1196 | | // extract token embeddings |
1197 | 0 | GGML_ASSERT(embd != nullptr); |
1198 | 0 | const uint32_t n_embd_out = hparams.get_n_embd_out(); |
1199 | |
|
1200 | 0 | GGML_ASSERT(n_tokens*n_embd_out <= (int64_t) embd_size); |
1201 | 0 | ggml_backend_tensor_get_async(backend_embd, t_embd, embd, 0, n_tokens*n_embd_out*sizeof(float)); |
1202 | 0 | } break; |
1203 | 0 | case LLAMA_POOLING_TYPE_MEAN: |
1204 | 0 | case LLAMA_POOLING_TYPE_CLS: |
1205 | 0 | case LLAMA_POOLING_TYPE_LAST: |
1206 | 0 | { |
1207 | | // extract sequence embeddings |
1208 | 0 | auto & embd_seq_out = embd_seq; |
1209 | |
|
1210 | 0 | for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) { |
1211 | 0 | const llama_seq_id seq_id = ubatch.seq_id_unq[s]; |
1212 | 0 | const int32_t seq_idx = ubatch.seq_idx[seq_id]; |
1213 | |
|
1214 | 0 | embd_seq_out[seq_id].resize(n_embd); |
1215 | 0 | ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_embd*seq_idx)*sizeof(float), n_embd*sizeof(float)); |
1216 | 0 | } |
1217 | 0 | } break; |
1218 | 0 | case LLAMA_POOLING_TYPE_RANK: |
1219 | 0 | { |
1220 | | // extract the rerank score - n_cls_out floats per sequence |
1221 | 0 | auto & embd_seq_out = embd_seq; |
1222 | |
|
1223 | 0 | const uint32_t n_cls_out = hparams.n_cls_out; |
1224 | |
|
1225 | 0 | for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) { |
1226 | 0 | const llama_seq_id seq_id = ubatch.seq_id_unq[s]; |
1227 | 0 | const int32_t seq_idx = ubatch.seq_idx[seq_id]; |
1228 | |
|
1229 | 0 | embd_seq_out[seq_id].resize(n_cls_out); |
1230 | 0 | ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_cls_out*seq_idx)*sizeof(float), n_cls_out*sizeof(float)); |
1231 | 0 | } |
1232 | 0 | } break; |
1233 | 0 | case LLAMA_POOLING_TYPE_UNSPECIFIED: |
1234 | 0 | { |
1235 | 0 | GGML_ABORT("unknown pooling type"); |
1236 | 0 | } |
1237 | 0 | } |
1238 | 0 | } |
1239 | | |
1240 | | // TODO: hacky solution |
1241 | 0 | if (model.arch == LLM_ARCH_T5 && t_embd) { |
1242 | | //cross.t_embd = t_embd; |
1243 | |
|
1244 | 0 | synchronize(); |
1245 | |
|
1246 | 0 | cross.n_embd = t_embd->ne[0]; |
1247 | 0 | cross.n_enc = t_embd->ne[1]; |
1248 | 0 | cross.v_embd.resize(cross.n_embd*cross.n_enc); |
1249 | 0 | memcpy(cross.v_embd.data(), embd, ggml_nbytes(t_embd)); |
1250 | |
|
1251 | 0 | const auto & batch = balloc->get_batch(); |
1252 | | |
1253 | | // remember the sequence ids used during the encoding - needed for cross attention later |
1254 | 0 | cross.seq_ids_enc.resize(n_tokens); |
1255 | 0 | for (uint32_t i = 0; i < n_tokens; i++) { |
1256 | 0 | cross.seq_ids_enc[i].clear(); |
1257 | |
|
1258 | 0 | for (int s = 0; s < batch.n_seq_id[i]; s++) { |
1259 | 0 | const llama_seq_id seq_id = batch.seq_id[i][s]; |
1260 | |
|
1261 | 0 | cross.seq_ids_enc[i].insert(seq_id); |
1262 | 0 | } |
1263 | 0 | } |
1264 | 0 | } |
1265 | |
|
1266 | 0 | return 0; |
1267 | 0 | } |
1268 | | |
1269 | 0 | static std::map<llama_seq_id, uint32_t> build_seq_to_output_row(const llama_ubatch & ubatch, uint32_t row_offset) { |
1270 | 0 | std::map<llama_seq_id, uint32_t> seq_to_row; |
1271 | | // how many output tokens we have seen so far for this ubatch. |
1272 | 0 | uint32_t local = 0; |
1273 | 0 | for (uint32_t i = 0; i < ubatch.n_tokens; ++i) { |
1274 | | // skip tokens that are not output. |
1275 | 0 | if (!ubatch.output[i]) { |
1276 | 0 | continue; |
1277 | 0 | } |
1278 | | |
1279 | 0 | const llama_seq_id seq_id = ubatch.seq_id[i][0]; |
1280 | | // row_offset is the number of output tokens before this ubatch. |
1281 | 0 | seq_to_row[seq_id] = row_offset + local; |
1282 | 0 | ++local; |
1283 | 0 | } |
1284 | 0 | return seq_to_row; |
1285 | 0 | } |
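
    A toy walkthrough of the mapping built above, using plain vectors in place of llama_ubatch; the output
    flags, sequence ids and row_offset are assumed values for illustration.

        #include <cstdio>
        #include <map>
        #include <vector>

        int main() {
            const std::vector<int> output = { 0, 1, 0, 1, 1 }; // assumed per-token output flags
            const std::vector<int> seq_id = { 0, 0, 1, 1, 2 }; // assumed first seq id of each token
            const unsigned row_offset = 4;                     // assumed rows produced by earlier ubatches

            std::map<int, unsigned> seq_to_row;
            unsigned local = 0;
            for (size_t i = 0; i < output.size(); ++i) {
                if (!output[i]) {
                    continue; // skip tokens that are not output
                }
                seq_to_row[seq_id[i]] = row_offset + local++;
            }

            for (const auto & [s, row] : seq_to_row) {
                printf("seq %d -> output row %u\n", s, row); // 0->4, 1->5, 2->6
            }
            return 0;
        }
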
1286 | | |
1287 | | static void copy_tensor_async_ints( |
1288 | | const std::map<llama_seq_id, ggml_tensor*> & tensor_map, |
1289 | | llama_token * sampled, |
1290 | | size_t sampled_size, |
1291 | | const std::map<llama_seq_id, uint32_t> & seq_to_row, |
1292 | 0 | ggml_backend_sched_t sched) { |
1293 | 0 | if (sampled == nullptr) { |
1294 | 0 | return; |
1295 | 0 | } |
1296 | | |
1297 | 0 | for (const auto & [seq_id, tensor] : tensor_map) { |
1298 | 0 | auto it = seq_to_row.find(seq_id); |
1299 | 0 | if (it == seq_to_row.end()) { |
1300 | 0 | continue; |
1301 | 0 | } |
1302 | | |
1303 | 0 | const uint32_t row = it->second; |
1304 | 0 | GGML_ASSERT(row < sampled_size); |
1305 | |
|
1306 | 0 | GGML_ASSERT(ggml_is_contiguous(tensor) && "sampled tokens tensor must be contiguous for async copy"); |
1307 | |
|
1308 | 0 | ggml_backend_t backend = ggml_backend_sched_get_tensor_backend(sched, tensor); |
1309 | 0 | ggml_backend_tensor_get_async(backend, tensor, sampled + row, 0, sizeof(sampled[row])); |
1310 | 0 | } |
1311 | 0 | } |
1312 | | |
1313 | | static void copy_tensor_async_floats( |
1314 | | const std::map<llama_seq_id, ggml_tensor*> & tensor_map, |
1315 | | float * dst, |
1316 | | size_t stride, |
1317 | | std::vector<uint32_t> & counts, |
1318 | | const std::map<llama_seq_id, uint32_t> & seq_to_row, |
1319 | 0 | ggml_backend_sched_t sched) { |
1320 | 0 | if (dst == nullptr) { |
1321 | 0 | return; |
1322 | 0 | } |
1323 | | |
1324 | 0 | for (const auto & [seq_id, tensor] : tensor_map) { |
1325 | 0 | auto it = seq_to_row.find(seq_id); |
1326 | 0 | if (it == seq_to_row.end()) { |
1327 | 0 | continue; |
1328 | 0 | } |
1329 | | |
1330 | 0 | const uint32_t row = it->second; |
1331 | 0 | GGML_ASSERT(row < counts.size()); |
1332 | |
|
1333 | 0 | GGML_ASSERT(ggml_is_contiguous(tensor) && "logits/probs tensor must be contiguous for async copy"); |
1334 | |
|
1335 | 0 | ggml_backend_t backend = ggml_backend_sched_get_tensor_backend(sched, tensor); |
1336 | 0 | float * row_ptr = dst + (size_t) row * stride; |
1337 | 0 | ggml_backend_tensor_get_async(backend, tensor, row_ptr, 0, ggml_nbytes(tensor)); |
1338 | | |
1339 | | // Update the actual number of logits/probabilities that were written for this row. |
1340 | 0 | counts[row] = ggml_nelements(tensor); |
1341 | 0 | } |
1342 | 0 | } |
1343 | | |
1344 | | static void copy_tensor_async_candidates( |
1345 | | const std::map<llama_seq_id, ggml_tensor*> & tensor_map, |
1346 | | llama_token * dst, |
1347 | | size_t stride, |
1348 | | std::vector<uint32_t> & counts, |
1349 | | const std::map<llama_seq_id, uint32_t> & seq_to_row, |
1350 | 0 | ggml_backend_sched_t sched) { |
1351 | 0 | if (dst == nullptr) { |
1352 | 0 | return; |
1353 | 0 | } |
1354 | | |
1355 | 0 | for (const auto & [seq_id, tensor] : tensor_map) { |
1356 | 0 | auto it = seq_to_row.find(seq_id); |
1357 | 0 | if (it == seq_to_row.end()) { |
1358 | 0 | continue; |
1359 | 0 | } |
1360 | | |
1361 | 0 | const uint32_t row = it->second; |
1362 | 0 | GGML_ASSERT(row < counts.size()); |
1363 | |
|
1364 | 0 | GGML_ASSERT(ggml_is_contiguous(tensor) && "candidates tensor must be contiguous for async copy"); |
1365 | |
|
1366 | 0 | ggml_backend_t backend = ggml_backend_sched_get_tensor_backend(sched, tensor); |
1367 | 0 | llama_token * row_ptr = dst + (size_t) row * stride; |
1368 | 0 | ggml_backend_tensor_get_async(backend, tensor, row_ptr, 0, ggml_nbytes(tensor)); |
1369 | | |
1370 | | // Update the actual number of candidates that were written. |
1371 | 0 | counts[row] = ggml_nelements(tensor); |
1372 | 0 | } |
1373 | 0 | } |
1374 | | |
1375 | 0 | int llama_context::decode(const llama_batch & batch_inp) { |
1376 | 0 | GGML_ASSERT((!batch_inp.token && batch_inp.embd) || (batch_inp.token && !batch_inp.embd)); // NOLINT |
1377 | |
|
1378 | 0 | if (!memory) { |
1379 | 0 | LLAMA_LOG_DEBUG("%s: cannot decode batches with this context (calling encode() instead)\n", __func__); |
1380 | 0 | return encode(batch_inp); |
1381 | 0 | } |
1382 | | |
1383 | 0 | if (batch_inp.n_tokens == 0) { |
1384 | 0 | LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__); |
1385 | 0 | return -1; |
1386 | 0 | } |
1387 | | |
1388 | 0 | const auto & vocab = model.vocab; |
1389 | 0 | const auto & hparams = model.hparams; |
1390 | |
|
1391 | 0 | const int64_t n_vocab = vocab.n_tokens(); |
1392 | 0 | const int64_t n_embd = hparams.n_embd_inp(); |
1393 | | |
1394 | | // when computing embeddings, all tokens are output |
1395 | 0 | const bool output_all = cparams.embeddings; |
1396 | 0 | const bool has_samplers = !sampling.samplers.empty(); |
1397 | |
|
1398 | 0 | const uint32_t n_seq_max = cparams.kv_unified ? LLAMA_MAX_SEQ : cparams.n_seq_max; |
1399 | | |
1400 | | // TODO: avoid this workaround in the future |
1401 | 0 | if (has_samplers && batch_inp.logits) { |
1402 | 0 | std::vector<int32_t> seq_output_count(n_seq_max, 0); |
1403 | |
|
1404 | 0 | for (int32_t i = 0; i < batch_inp.n_tokens; ++i) { |
1405 | 0 | if (batch_inp.logits[i] == 0) { |
1406 | 0 | continue; |
1407 | 0 | } |
1408 | | |
1409 | 0 | const int ns = batch_inp.n_seq_id ? batch_inp.n_seq_id[i] : 1; |
1410 | |
|
1411 | 0 | for (int32_t s = 0; s < ns; ++s) { |
1412 | 0 | const llama_seq_id seq_id = batch_inp.seq_id ? batch_inp.seq_id[i][s] : 0; |
1413 | |
|
1414 | 0 | seq_output_count[seq_id]++; |
1415 | 0 | if (seq_output_count[seq_id] > 1) { |
1416 | 0 | LLAMA_LOG_ERROR("%s: backend sampling requires at most one output token per sequence (seq_id %d had %d)\n", |
1417 | 0 | __func__, seq_id, seq_output_count[seq_id]); |
1418 | 0 | return -1; |
1419 | 0 | } |
1420 | 0 | } |
1421 | 0 | } |
1422 | 0 | } |
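 | | // e.g. with a backend sampler attached, a batch that requests output (logits[i] != 0) at two |
 | | // positions of the same seq_id is rejected by the check above. |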
1423 | | |
1424 | 0 | if (!balloc->init(batch_inp, vocab, memory.get(), n_embd, n_seq_max, output_all)) { |
1425 | 0 | LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__); |
1426 | 0 | return -1; |
1427 | 0 | } |
1428 | | |
1429 | 0 | const uint32_t n_tokens_all = balloc->get_n_tokens(); |
1430 | 0 | const uint32_t n_outputs_all = balloc->get_n_outputs(); |
1431 | |
|
1432 | 0 | if (output_all) { |
1433 | | // require that all tokens are output |
1434 | 0 | if (n_outputs_all != n_tokens_all) { |
1435 | 0 | LLAMA_LOG_ERROR("%s: pooled embedding requires that all tokens are output (n_outputs_all = %d, n_tokens_all = %d)\n", |
1436 | 0 | __func__, n_outputs_all, n_tokens_all); |
1437 | 0 | return -1; |
1438 | 0 | } |
1439 | 0 | } |
1440 | | |
1441 | 0 | GGML_ASSERT(n_tokens_all <= cparams.n_batch); |
1442 | |
|
1443 | 0 | GGML_ASSERT((cparams.causal_attn || cparams.n_ubatch >= n_tokens_all) && "non-causal attention requires n_ubatch >= n_tokens"); |
1444 | |
|
1445 | 0 | if (t_compute_start_us == 0) { |
1446 | 0 | t_compute_start_us = ggml_time_us(); |
1447 | 0 | } |
1448 | 0 | n_queued_tokens += n_tokens_all; |
1449 | | |
1450 | | // TODO: clearing these buffers can easily be forgotten - need something better |
1451 | 0 | embd_seq.clear(); |
1452 | 0 | output_swaps.clear(); |
1453 | |
|
1454 | 0 | bool did_optimize = false; |
1455 | | |
1456 | | // handle any pending shifts/copies |
1457 | 0 | memory_update(false); |
1458 | |
|
1459 | 0 | llama_memory_context_ptr mctx; |
1460 | |
|
1461 | 0 | while (true) { |
1462 | 0 | mctx = memory->init_batch(*balloc, cparams.n_ubatch, output_all); |
1463 | 0 | if (!mctx) { |
1464 | 0 | return -2; |
1465 | 0 | } |
1466 | | |
1467 | 0 | switch (mctx->get_status()) { |
1468 | 0 | case LLAMA_MEMORY_STATUS_SUCCESS: |
1469 | 0 | { |
1470 | 0 | } break; |
1471 | 0 | case LLAMA_MEMORY_STATUS_NO_UPDATE: |
1472 | 0 | { |
1473 | 0 | LLAMA_LOG_ERROR("%s: unexpected memory context status: %d\n", __func__, mctx->get_status()); |
1474 | |
|
1475 | 0 | return -2; |
1476 | 0 | } |
1477 | 0 | case LLAMA_MEMORY_STATUS_FAILED_PREPARE: |
1478 | 0 | { |
1479 | 0 | if (!did_optimize) { |
1480 | 0 | did_optimize = true; |
1481 | |
|
1482 | 0 | if (memory_update(true)) { |
1483 | 0 | LLAMA_LOG_DEBUG("%s: retrying batch size %d after cache optimization\n", __func__, balloc->get_n_tokens()); |
1484 | |
|
1485 | 0 | continue; |
1486 | 0 | } |
1487 | 0 | } |
1488 | | |
1489 | 0 | LLAMA_LOG_WARN("%s: failed to find a memory slot for batch of size %d\n", __func__, balloc->get_n_tokens()); |
1490 | |
|
1491 | 0 | return 1; |
1492 | 0 | } |
1493 | 0 | case LLAMA_MEMORY_STATUS_FAILED_COMPUTE: |
1494 | 0 | { |
1495 | 0 | LLAMA_LOG_ERROR("%s: compute failed while preparing batch of size %d\n", __func__, balloc->get_n_tokens()); |
1496 | |
|
1497 | 0 | return -2; |
1498 | 0 | } |
1499 | 0 | } |
1500 | | |
1501 | 0 | break; |
1502 | 0 | } |
1503 | | |
1504 | | // reserve output buffer |
1505 | 0 | if (output_reserve(n_outputs_all, balloc->get_batch()) < n_outputs_all) { |
1506 | 0 | LLAMA_LOG_ERROR("%s: could not reserve space for batch with %d outputs\n", __func__, n_outputs_all); |
1507 | 0 | return -2; |
1508 | 0 | } |
1509 | |
|
1510 | 0 | int64_t n_outputs_prev = 0; |
1511 | |
|
1512 | 0 | do { |
1513 | 0 | const auto & ubatch = mctx->get_ubatch(); |
1514 | | |
1515 | | // count the outputs in this ubatch |
1516 | 0 | { |
1517 | 0 | int32_t n_outputs_new = 0; |
1518 | |
|
1519 | 0 | if (n_outputs_all == n_tokens_all) { |
1520 | 0 | n_outputs_new = ubatch.n_tokens; |
1521 | 0 | } else { |
1522 | 0 | for (uint32_t i = 0; i < ubatch.n_tokens; i++) { |
1523 | 0 | n_outputs_new += (int32_t) (ubatch.output[i] != 0); |
1524 | 0 | } |
1525 | 0 | } |
1526 | | |
1527 | | // needs to happen before the graph is built |
1528 | 0 | n_outputs = n_outputs_new; |
1529 | 0 | } |
1530 | |
|
1531 | 0 | ggml_status status; |
1532 | 0 | const auto * res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, mctx.get(), status); |
1533 | |
|
1534 | 0 | if (!res) { |
1535 | | // the last ubatch failed or was aborted -> remove all positions of that ubatch from the memory module |
1536 | 0 | llama_pos pos_min[LLAMA_MAX_SEQ]; |
1537 | 0 | for (int s = 0; s < LLAMA_MAX_SEQ; ++s) { |
1538 | 0 | pos_min[s] = std::numeric_limits<llama_pos>::max(); |
1539 | 0 | } |
1540 | |
|
1541 | 0 | for (uint32_t i = 0; i < ubatch.n_tokens; ++i) { |
1542 | 0 | const auto & seq_id = ubatch.seq_id[i][0]; |
1543 | |
|
1544 | 0 | pos_min[seq_id] = std::min(pos_min[seq_id], ubatch.pos[i]); |
1545 | 0 | } |
1546 | |
|
1547 | 0 | for (int s = 0; s < LLAMA_MAX_SEQ; ++s) { |
1548 | 0 | if (pos_min[s] == std::numeric_limits<llama_pos>::max()) { |
1549 | 0 | continue; |
1550 | 0 | } |
1551 | | |
1552 | 0 | LLAMA_LOG_WARN("%s: removing memory module entries for seq_id = %d, pos = [%d, +inf)\n", __func__, s, pos_min[s]); |
1553 | |
|
1554 | 0 | memory->seq_rm(s, pos_min[s], -1); |
1555 | 0 | } |
1556 | |
|
1557 | 0 | switch (status) { |
1558 | 0 | case GGML_STATUS_ABORTED: return 2; |
1559 | 0 | case GGML_STATUS_ALLOC_FAILED: return -2; |
1560 | 0 | case GGML_STATUS_FAILED: return -3; |
1561 | 0 | case GGML_STATUS_SUCCESS: GGML_ABORT("should not happen"); |
1562 | 0 | } |
1563 | 0 | } |
1564 | | |
1565 | | // plot the computation graph in dot format (for debugging purposes) |
1566 | | //if (n_past%100 == 0) { |
1567 | | // ggml_graph_dump_dot(gf, NULL, "llama.dot"); |
1568 | | //} |
1569 | | |
1570 | 0 | auto * t_logits = res->get_logits(); |
1571 | 0 | auto * t_embd = cparams.embeddings ? res->get_embd() : nullptr; |
1572 | |
|
1573 | 0 | if (t_embd && res->get_embd_pooled()) { |
1574 | 0 | t_embd = res->get_embd_pooled(); |
1575 | 0 | } |
1576 | | |
1577 | | // extract logits |
1578 | | // For multi-sequence batches that mix backend samplers and the CPU sampler, |
1579 | | // this is currently inefficient as we copy all logits even for the |
1580 | | // backend-sampled tokens. |
1581 | 0 | if (logits && t_logits && n_outputs > 0) { |
1582 | 0 | ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(sched.get(), t_logits); |
1583 | 0 | GGML_ASSERT(backend_res != nullptr); |
1584 | 0 | GGML_ASSERT(logits != nullptr); |
1585 | |
|
1586 | 0 | float * logits_out = logits + n_outputs_prev*n_vocab; |
1587 | |
|
1588 | 0 | if (n_outputs) { |
1589 | 0 | GGML_ASSERT( n_outputs_prev + n_outputs <= n_outputs_all); |
1590 | 0 | GGML_ASSERT((n_outputs_prev + n_outputs)*n_vocab <= (int64_t) logits_size); |
1591 | 0 | ggml_backend_tensor_get_async(backend_res, t_logits, logits_out, 0, n_outputs*n_vocab*sizeof(float)); |
1592 | 0 | } |
1593 | 0 | } |
1594 | | |
1595 | | // extract embeddings |
1596 | 0 | if (embd && t_embd && n_outputs > 0) { |
1597 | 0 | ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(sched.get(), t_embd); |
1598 | 0 | GGML_ASSERT(backend_embd != nullptr); |
1599 | |
|
1600 | 0 | switch (cparams.pooling_type) { |
1601 | 0 | case LLAMA_POOLING_TYPE_NONE: |
1602 | 0 | { |
1603 | | // extract token embeddings |
1604 | 0 | GGML_ASSERT(embd != nullptr); |
1605 | 0 | const uint32_t n_embd_out = hparams.get_n_embd_out(); |
1606 | 0 | float * embd_out = embd + n_outputs_prev*n_embd_out; |
1607 | |
|
1608 | 0 | if (n_outputs) { |
1609 | 0 | GGML_ASSERT( n_outputs_prev + n_outputs <= n_outputs_all); |
1610 | 0 | GGML_ASSERT((n_outputs_prev + n_outputs)*n_embd_out <= (int64_t) embd_size); |
1611 | 0 | ggml_backend_tensor_get_async(backend_embd, t_embd, embd_out, 0, n_outputs*n_embd_out*sizeof(float)); |
1612 | 0 | } |
1613 | 0 | } break; |
1614 | 0 | case LLAMA_POOLING_TYPE_MEAN: |
1615 | 0 | case LLAMA_POOLING_TYPE_CLS: |
1616 | 0 | case LLAMA_POOLING_TYPE_LAST: |
1617 | 0 | { |
1618 | | // extract sequence embeddings (cleared before processing each batch) |
1619 | 0 | auto & embd_seq_out = embd_seq; |
1620 | |
|
1621 | 0 | for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) { |
1622 | 0 | const llama_seq_id seq_id = ubatch.seq_id_unq[s]; |
1623 | 0 | const int32_t seq_idx = ubatch.seq_idx[seq_id]; |
1624 | |
|
1625 | 0 | embd_seq_out[seq_id].resize(n_embd); |
1626 | 0 | ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_embd*seq_idx)*sizeof(float), n_embd*sizeof(float)); |
1627 | 0 | } |
1628 | 0 | } break; |
1629 | 0 | case LLAMA_POOLING_TYPE_RANK: |
1630 | 0 | { |
1631 | | // extract the rerank score - n_cls_out floats per sequence |
1632 | 0 | auto & embd_seq_out = embd_seq; |
1633 | |
|
1634 | 0 | const uint32_t n_cls_out = hparams.n_cls_out; |
1635 | |
|
1636 | 0 | for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) { |
1637 | 0 | const llama_seq_id seq_id = ubatch.seq_id_unq[s]; |
1638 | 0 | const int32_t seq_idx = ubatch.seq_idx[seq_id]; |
1639 | |
|
1640 | 0 | embd_seq_out[seq_id].resize(n_cls_out); |
1641 | 0 | ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_cls_out*seq_idx)*sizeof(float), n_cls_out*sizeof(float)); |
1642 | 0 | } |
1643 | 0 | } break; |
1644 | 0 | case LLAMA_POOLING_TYPE_UNSPECIFIED: |
1645 | 0 | { |
1646 | 0 | GGML_ABORT("unknown pooling type"); |
1647 | 0 | } |
1648 | 0 | } |
1649 | 0 | } |
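 | | // summary: POOLING_TYPE_NONE writes n_embd_out floats per output token into embd, while MEAN/CLS/LAST |
 | | // write one n_embd vector per sequence and RANK writes n_cls_out scores per sequence into embd_seq. |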
1650 | | |
1651 | | // This flag indicates whether a backend sampler has actually sampled a specific |
1652 | | // token, or has produced probabilities or logits. If true, we can skip the normal copying of logits and embeddings. |
1653 | 0 | const bool has_sampled = !res->t_sampled.empty() || !res->t_sampled_probs.empty() || !res->t_sampled_logits.empty(); |
1654 | |
|
1655 | 0 | if (has_samplers && has_sampled) { |
1656 | 0 | const auto seq_to_output_row = build_seq_to_output_row(ubatch, n_outputs_prev); |
1657 | 0 | const auto stride = n_vocab; |
1658 | | |
1659 | | // async copy the sampling data from the backend to the host |
1660 | 0 | copy_tensor_async_ints(res->t_sampled, sampling.sampled, sampling.sampled_size, seq_to_output_row, sched.get()); |
1661 | |
|
1662 | 0 | copy_tensor_async_floats (res->t_sampled_logits, sampling.logits, stride, sampling.logits_count, seq_to_output_row, sched.get()); |
1663 | 0 | copy_tensor_async_floats (res->t_sampled_probs, sampling.probs, stride, sampling.probs_count, seq_to_output_row, sched.get()); |
1664 | 0 | copy_tensor_async_candidates(res->t_candidates, sampling.candidates, stride, sampling.candidates_count, seq_to_output_row, sched.get()); |
1665 | 0 | } |
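 | | // note: the host-side sampling buffers use one full-vocab row per output (stride == n_vocab), |
 | | // so each sequence's logits/probs/candidates land at row * n_vocab. |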
1666 | |
|
1667 | 0 | n_outputs_prev += n_outputs; |
1668 | 0 | } while (mctx->next()); |
1669 | | |
1670 | | // set to total number of outputs in the batch, for use in llama_get_logits_ith |
1671 | 0 | n_outputs = n_outputs_all; |
1672 | | |
1673 | | // set output mappings |
1674 | 0 | if (n_outputs > 0) { |
1675 | 0 | bool sorted_output = true; |
1676 | |
|
1677 | 0 | auto & out_ids = balloc->get_out_ids(); |
1678 | |
|
1679 | 0 | GGML_ASSERT(out_ids.size() == (size_t) n_outputs); |
1680 | |
|
1681 | 0 | for (int64_t i = 0; i < n_outputs; ++i) { |
1682 | 0 | int64_t out_id = out_ids[i]; |
1683 | 0 | output_ids[out_id] = i; |
1684 | 0 | if (out_id != i) { |
1685 | 0 | sorted_output = false; |
1686 | 0 | } |
1687 | 0 | } |
1688 | | |
1689 | | // make the outputs have the same order they had in the user-provided batch |
1690 | | // note: this is mostly relevant for recurrent models atm |
1691 | 0 | if (!sorted_output && n_outputs > 1) { |
1692 | 0 | GGML_ASSERT((size_t) n_outputs == out_ids.size()); |
1693 | | |
1694 | | // TODO: is there something more efficient which also minimizes swaps? |
1695 | | // selection sort, to minimize swaps (from https://en.wikipedia.org/wiki/Selection_sort) |
1696 | 0 | for (uint32_t i = 0; i < n_outputs - 1; ++i) { |
1697 | 0 | uint32_t j_min = i; |
1698 | 0 | for (uint32_t j = i + 1; j < n_outputs; ++j) { |
1699 | 0 | if (out_ids[j] < out_ids[j_min]) { |
1700 | 0 | j_min = j; |
1701 | 0 | } |
1702 | 0 | } |
1703 | 0 | if (j_min == i) { |
1704 | 0 | continue; |
1705 | 0 | } |
1706 | 0 | std::swap(out_ids[i], out_ids[j_min]); |
1707 | | |
1708 | | // remember the swaps and apply them lazily upon logits/embeddings access |
1709 | 0 | output_swaps.push_back({ i, j_min }); |
1710 | 0 | } |
1711 | |
|
1712 | 0 | std::fill(output_ids.begin(), output_ids.end(), -1); |
1713 | |
|
1714 | 0 | for (uint32_t i = 0; i < n_outputs; ++i) { |
1715 | 0 | output_ids[out_ids[i]] = i; |
1716 | 0 | } |
1717 | 0 | } |
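 | | // e.g. for out_ids = {2, 0, 1} the selection sort above records the swaps (0, 1) and (1, 2), |
 | | // which output_reorder() later applies to the logits/embeddings rows. |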
1718 | 0 | } |
1719 | | |
1720 | | // wait for the computation to finish (automatically done when obtaining the model output) |
1721 | | //synchronize(); |
1722 | |
|
1723 | 0 | return 0; |
1724 | 0 | } |
1725 | | |
1726 | | // |
1727 | | // output |
1728 | | // |
1729 | | |
1730 | 0 | uint32_t llama_context::output_reserve(int32_t n_outputs, const llama_batch & batch) { |
1731 | 0 | const auto & hparams = model.hparams; |
1732 | 0 | const auto & vocab = model.vocab; |
1733 | |
|
1734 | 0 | const int64_t n_outputs_max = std::max<int64_t>(n_outputs, n_seq_max()); |
1735 | |
|
1736 | 0 | const auto n_batch = cparams.n_batch; |
1737 | 0 | const auto n_vocab = vocab.n_tokens(); |
1738 | 0 | const auto n_embd_out = hparams.get_n_embd_out(); |
1739 | |
|
1740 | 0 | bool has_logits = true; |
1741 | 0 | bool has_embd = cparams.embeddings; |
1742 | | |
1743 | | // TODO: hacky enc-dec support |
1744 | 0 | if (model.arch == LLM_ARCH_T5) { |
1745 | 0 | has_logits = true; |
1746 | 0 | has_embd = true; |
1747 | 0 | } |
1748 | | |
1749 | | // Check which sampling modes are needed for the current batch. |
1750 | | // TODO: avoid this branching by working with the worst-case |
1751 | 0 | bool has_sampling = false; |
1752 | 0 | bool cpu_logits = false; |
1753 | |
|
1754 | 0 | if (batch.logits) { |
1755 | 0 | for (int32_t i = 0; i < batch.n_tokens; i++) { |
1756 | 0 | if (!batch.logits[i]) { |
1757 | 0 | continue; |
1758 | 0 | } |
1759 | 0 | for (int32_t j = 0; j < batch.n_seq_id[i]; j++) { |
1760 | 0 | llama_seq_id seq_id = batch.seq_id[i][j]; |
1761 | 0 | if (sampling.samplers.find(seq_id) != sampling.samplers.end()) { |
1762 | 0 | has_sampling = true; |
1763 | 0 | } else { |
1764 | 0 | cpu_logits = true; |
1765 | 0 | } |
1766 | 0 | } |
1767 | 0 | } |
1768 | 0 | } else { |
1769 | | // When batch.logits is nullptr (e.g. when loading state with a dummy batch), |
1770 | | // allocate the CPU logits buffer. |
1771 | 0 | cpu_logits = true; |
1772 | 0 | } |
1773 | |
|
1774 | 0 | size_t backend_float_count = 0; |
1775 | 0 | size_t backend_token_count = 0; |
1776 | | |
1777 | | // Allocate CPU logits buffer only if needed by sequences in this batch |
1778 | 0 | logits_size = (has_logits && cpu_logits) ? n_vocab*n_outputs_max : 0; |
1779 | 0 | embd_size = has_embd ? n_embd_out*n_outputs_max : 0; |
1780 | | |
1781 | | // TODO: avoid this branching by working with the worst-case |
1782 | 0 | if (!has_sampling) { |
1783 | 0 | sampling.logits_size = 0; |
1784 | 0 | sampling.probs_size = 0; |
1785 | 0 | sampling.sampled_size = 0; |
1786 | 0 | sampling.candidates_size = 0; |
1787 | 0 | } else { |
1788 | 0 | sampling.logits_size = n_vocab*n_outputs_max; |
1789 | 0 | sampling.probs_size = n_vocab*n_outputs_max; |
1790 | 0 | sampling.sampled_size = n_outputs_max; |
1791 | 0 | sampling.candidates_size = n_vocab*n_outputs_max; |
1792 | |
|
1793 | 0 | backend_float_count = sampling.logits_size + sampling.probs_size; |
1794 | 0 | backend_token_count = sampling.sampled_size + sampling.candidates_size; |
1795 | 0 | } |
1796 | |
|
1797 | 0 | if (output_ids.empty()) { |
1798 | | // init, never resized afterwards |
1799 | 0 | output_ids.resize(n_batch); |
1800 | 0 | } |
1801 | |
|
1802 | 0 | const size_t prev_size = buf_output ? ggml_backend_buffer_get_size(buf_output.get()) : 0; |
1803 | 0 | const size_t new_size = |
1804 | 0 | (logits_size + embd_size + backend_float_count) * sizeof(float) + |
1805 | 0 | ( backend_token_count) * sizeof(llama_token); |
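 | | // the single host output buffer is carved up below in this order: |
 | | // [ cpu logits, embeddings, sampling logits, sampling probs, sampled tokens, candidate tokens ] |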
1806 | | |
1807 | | // alloc only when more than the current capacity is required |
1808 | | // TODO: also consider shrinking the buffer |
1809 | 0 | if (!buf_output || prev_size < new_size) { |
1810 | 0 | if (buf_output) { |
1811 | | #ifndef NDEBUG |
1812 | | // This doesn't happen often, but may be annoying in some cases (like the HellaSwag benchmark) |
1813 | | LLAMA_LOG_DEBUG("%s: reallocating output buffer from size %.02f MiB to %.02f MiB\n", __func__, prev_size / 1024.0 / 1024.0, new_size / 1024.0 / 1024.0); |
1814 | | #endif |
1815 | 0 | synchronize(); |
1816 | | |
1817 | | // TODO: not needed? |
1818 | 0 | buf_output = nullptr; |
1819 | 0 | logits = nullptr; |
1820 | 0 | embd = nullptr; |
1821 | 0 | } |
1822 | |
|
1823 | 0 | auto * buft = ggml_backend_cpu_buffer_type(); |
1824 | | // try to use the host buffer of the device where the output tensor is allocated for faster transfer to system memory |
1825 | 0 | auto * output_dev = model.dev_output(); |
1826 | 0 | auto * output_dev_host_buft = output_dev ? ggml_backend_dev_host_buffer_type(output_dev) : nullptr; |
1827 | 0 | if (output_dev_host_buft) { |
1828 | 0 | buft = output_dev_host_buft; |
1829 | 0 | } |
1830 | 0 | buf_output.reset(ggml_backend_buft_alloc_buffer(buft, new_size)); |
1831 | 0 | if (buf_output == nullptr) { |
1832 | 0 | LLAMA_LOG_ERROR("%s: failed to allocate output buffer of size %.2f MiB\n", __func__, new_size / (1024.0 * 1024.0)); |
1833 | 0 | return 0; |
1834 | 0 | } |
1835 | 0 | } |
1836 | | |
1837 | 0 | float * output_base = (float *) ggml_backend_buffer_get_base(buf_output.get()); |
1838 | |
|
1839 | 0 | logits = nullptr; |
1840 | 0 | embd = nullptr; |
1841 | |
|
1842 | 0 | size_t offset = 0; |
1843 | 0 | uint8_t * base = (uint8_t *) output_base; |
1844 | |
|
1845 | 0 | logits = (has_logits && cpu_logits) ? output_base : nullptr; |
1846 | 0 | offset += logits_size * sizeof(float); |
1847 | |
|
1848 | 0 | embd = has_embd ? (float *) (base + offset) : nullptr; |
1849 | 0 | offset += embd_size * sizeof(float); |
1850 | |
|
1851 | 0 | sampling.logits = nullptr; |
1852 | 0 | sampling.probs = nullptr; |
1853 | 0 | sampling.sampled = nullptr; |
1854 | 0 | sampling.candidates = nullptr; |
1855 | |
|
1856 | 0 | if (has_sampling) { |
1857 | 0 | sampling.logits = (float *) (base + offset); |
1858 | 0 | offset += sampling.logits_size * sizeof(float); |
1859 | |
|
1860 | 0 | sampling.probs = (float *) (base + offset); |
1861 | 0 | offset += sampling.probs_size * sizeof(float); |
1862 | |
|
1863 | 0 | sampling.sampled = (llama_token *) (base + offset); |
1864 | 0 | offset += sampling.sampled_size * sizeof(llama_token); |
1865 | |
|
1866 | 0 | sampling.candidates = (llama_token *) (base + offset); |
1867 | 0 | offset += sampling.candidates_size * sizeof(llama_token); |
1868 | | |
1869 | | // The count vectors keep track of the actual number of logits/probs/candidates |
1870 | | // copied from the backend for each output row. |
1871 | |
|
1872 | 0 | sampling.logits_count.resize(n_outputs_max); |
1873 | 0 | sampling.probs_count.resize(n_outputs_max); |
1874 | 0 | sampling.candidates_count.resize(n_outputs_max); |
1875 | |
|
1876 | 0 | std::fill(sampling.logits_count.begin(), sampling.logits_count.end(), 0); |
1877 | 0 | std::fill(sampling.probs_count.begin(), sampling.probs_count.end(), 0); |
1878 | 0 | std::fill(sampling.candidates_count.begin(), sampling.candidates_count.end(), 0); |
1879 | |
|
1880 | 0 | std::fill_n(sampling.sampled, sampling.sampled_size, LLAMA_TOKEN_NULL); |
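 | | // pre-filling with LLAMA_TOKEN_NULL presumably lets callers detect rows for which no backend |
 | | // sampler produced a token (copy_tensor_async_ints only overwrites rows present in its tensor map) |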
1881 | 0 | } |
1882 | | |
1883 | | // set all ids as invalid (negative) |
1884 | 0 | std::fill(output_ids.begin(), output_ids.end(), -1); |
1885 | |
|
1886 | 0 | this->n_outputs = 0; |
1887 | |
|
1888 | 0 | return n_outputs_max; |
1889 | 0 | } |
1890 | | |
1891 | 0 | void llama_context::output_reorder() { |
1892 | 0 | const uint64_t n_vocab = model.vocab.n_tokens(); |
1893 | 0 | const uint64_t n_embd = model.hparams.n_embd; |
1894 | |
|
1895 | 0 | for (size_t s = 0; s < output_swaps.size(); ++s) { |
1896 | 0 | const uint64_t i0 = output_swaps[s].i0; |
1897 | 0 | const uint64_t i1 = output_swaps[s].i1; |
1898 | |
|
1899 | 0 | if (logits_size > 0) { |
1900 | 0 | for (uint64_t k = 0; k < n_vocab; k++) { |
1901 | 0 | std::swap(logits[i0*n_vocab + k], logits[i1*n_vocab + k]); |
1902 | 0 | } |
1903 | 0 | } |
1904 | |
|
1905 | 0 | if (embd_size > 0) { |
1906 | 0 | for (uint64_t k = 0; k < n_embd; k++) { |
1907 | 0 | std::swap(embd[i0*n_embd + k], embd[i1*n_embd + k]); |
1908 | 0 | } |
1909 | 0 | } |
1910 | |
|
1911 | 0 | if (sampling.logits && sampling.logits_size > 0) { |
1912 | 0 | for (uint64_t k = 0; k < n_vocab; ++k) { |
1913 | 0 | std::swap(sampling.logits[i0*n_vocab + k], sampling.logits[i1*n_vocab + k]); |
1914 | 0 | } |
1915 | 0 | } |
1916 | |
|
1917 | 0 | if (sampling.probs && sampling.probs_size > 0) { |
1918 | 0 | for (uint64_t k = 0; k < n_vocab; ++k) { |
1919 | 0 | std::swap(sampling.probs[i0*n_vocab + k], sampling.probs[i1*n_vocab + k]); |
1920 | 0 | } |
1921 | 0 | } |
1922 | |
|
1923 | 0 | if (sampling.candidates && sampling.candidates_size > 0) { |
1924 | 0 | for (uint64_t k = 0; k < n_vocab; ++k) { |
1925 | 0 | std::swap(sampling.candidates[i0*n_vocab + k], sampling.candidates[i1*n_vocab + k]); |
1926 | 0 | } |
1927 | 0 | } |
1928 | |
|
1929 | 0 | if (sampling.sampled && sampling.sampled_size > 0) { |
1930 | 0 | std::swap(sampling.sampled[i0], sampling.sampled[i1]); |
1931 | 0 | } |
1932 | |
|
1933 | 0 | if (!sampling.logits_count.empty()) { |
1934 | 0 | std::swap(sampling.logits_count[i0], sampling.logits_count[i1]); |
1935 | 0 | } |
1936 | |
|
1937 | 0 | if (!sampling.probs_count.empty()) { |
1938 | 0 | std::swap(sampling.probs_count[i0], sampling.probs_count[i1]); |
1939 | 0 | } |
1940 | |
|
1941 | 0 | if (!sampling.candidates_count.empty()) { |
1942 | 0 | std::swap(sampling.candidates_count[i0], sampling.candidates_count[i1]); |
1943 | 0 | } |
1944 | 0 | } |
1945 | |
|
1946 | 0 | output_swaps.clear(); |
1947 | 0 | } |
1948 | | |
1949 | | // |
1950 | | // graph |
1951 | | // |
1952 | | |
1953 | 0 | uint32_t llama_context::graph_max_nodes(uint32_t n_tokens) const { |
1954 | 0 | if (model.arch == LLM_ARCH_QWEN3NEXT) { |
1955 | 0 | return std::max<uint32_t>(n_tokens * 40, 32u * model.n_tensors()); |
1956 | 0 | } |
1957 | 0 | uint32_t res = std::max<uint32_t>(1024u, 8u*model.n_tensors()); |
1958 | 0 | res += model.n_lora_nodes; |
1959 | 0 | return res; |
1960 | 0 | } |
1961 | | |
1962 | 0 | llm_graph_result * llama_context::get_gf_res_reserve() const { |
1963 | 0 | return static_cast<llm_graph_result *>(gf_res_reserve.get()); |
1964 | 0 | } |
1965 | | |
1966 | | ggml_cgraph * llama_context::graph_reserve( |
1967 | 0 | uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx, bool split_only, size_t * sizes) { |
1968 | 0 | LLAMA_LOG_DEBUG("%s: reserving a graph for ubatch with n_tokens = %4u, n_seqs = %2u, n_outputs = %4u\n", __func__, n_tokens, n_seqs, n_outputs); |
1969 | 0 | GGML_ASSERT(n_outputs >= 1); |
1970 | |
|
1971 | 0 | if (n_tokens % n_seqs != 0) { |
1972 | 0 | n_tokens = ((n_tokens + (n_seqs - 1)) / n_seqs) * n_seqs; // round to next multiple of n_seqs |
1973 | 0 | n_outputs = std::max(n_outputs, n_tokens); |
1974 | |
|
1975 | 0 | LLAMA_LOG_DEBUG("%s: making n_tokens a multiple of n_seqs - n_tokens = %u, n_seqs = %u, n_outputs = %u\n", __func__, n_tokens, n_seqs, n_outputs); |
1976 | 0 | } |
1977 | |
|
1978 | 0 | ggml_backend_sched_reset(sched.get()); |
1979 | | |
1980 | | // when the scheduler is reset, we cannot reuse the old graph, so we reset the previous graph result to prevent that |
1981 | 0 | gf_res_prev->reset(); |
1982 | | |
1983 | | // store the current n_outputs and restore it afterwards |
1984 | | // TODO: not sure if needed, might simplify in the future by removing this |
1985 | 0 | const auto save_n_outputs = this->n_outputs; |
1986 | |
|
1987 | 0 | this->n_outputs = n_outputs; |
1988 | |
|
1989 | 0 | llama_batch_allocr balloc(model.hparams.n_pos_per_embd()); |
1990 | 0 | llama_ubatch ubatch = balloc.ubatch_reserve(n_tokens/n_seqs, n_seqs); |
1991 | | |
1992 | | // set one output token per sequence in order to activate all backend samplers |
1993 | 0 | std::vector<llama_seq_id> seq_ids(n_seqs); |
1994 | 0 | for (uint32_t i = 0; i < n_seqs; ++i) { |
1995 | 0 | seq_ids[i] = i; |
1996 | 0 | ubatch.n_seq_id[i] = 1; |
1997 | 0 | ubatch.seq_id[i] = &seq_ids[i]; |
1998 | 0 | ubatch.output[i] = true; |
1999 | 0 | } |
2000 | |
|
2001 | 0 | auto * res = gf_res_reserve.get(); |
2002 | |
|
2003 | 0 | const auto gparams = graph_params(res, ubatch, mctx, LLM_GRAPH_TYPE_DEFAULT); |
2004 | |
|
2005 | 0 | res->reset(); |
2006 | |
|
2007 | 0 | auto * gf = model.build_graph(gparams); |
2008 | |
|
2009 | 0 | this->n_outputs = save_n_outputs; |
2010 | | |
2011 | | // initialize scheduler with the specified graph |
2012 | 0 | if (split_only) { |
2013 | 0 | if (sizes) { |
2014 | 0 | ggml_backend_sched_reserve_size(sched.get(), gf, sizes); |
2015 | 0 | } else { |
2016 | 0 | ggml_backend_sched_split_graph(sched.get(), gf); |
2017 | 0 | } |
2018 | 0 | } else if (!ggml_backend_sched_reserve(sched.get(), gf)) { |
2019 | 0 | GGML_ASSERT(!sizes); |
2020 | 0 | LLAMA_LOG_ERROR("%s: failed to allocate compute buffers\n", __func__); |
2021 | 0 | return nullptr; |
2022 | 0 | } |
2023 | | |
2024 | 0 | return gf; |
2025 | 0 | } |
2026 | | |
2027 | | llm_graph_params llama_context::graph_params( |
2028 | | llm_graph_result * res, |
2029 | | const llama_ubatch & ubatch, |
2030 | | const llama_memory_context_i * mctx, |
2031 | 0 | llm_graph_type gtype) const { |
2032 | 0 | return { |
2033 | 0 | /*.arch =*/ model.arch, |
2034 | 0 | /*.hparams =*/ model.hparams, |
2035 | 0 | /*.cparams =*/ cparams, |
2036 | 0 | /*.ubatch =*/ ubatch, |
2037 | 0 | /*.gtype =*/ gtype, |
2038 | 0 | /*.sched =*/ sched.get(), |
2039 | 0 | /*.backend_cpu =*/ backend_cpu, |
2040 | 0 | /*.cvec =*/ &cvec, |
2041 | 0 | /*.loras =*/ &loras, |
2042 | 0 | /*.mctx =*/ mctx, |
2043 | 0 | /*.cross =*/ &cross, |
2044 | 0 | /*.samplers =*/ sampling.samplers, |
2045 | 0 | /*.n_outputs =*/ n_outputs, |
2046 | 0 | /*.cb =*/ graph_get_cb(), |
2047 | 0 | /*.res =*/ res, |
2048 | 0 | }; |
2049 | 0 | } |
2050 | | |
2051 | | ggml_status llama_context::graph_compute( |
2052 | | ggml_cgraph * gf, |
2053 | 0 | bool batched) { |
2054 | 0 | int n_threads = batched ? cparams.n_threads_batch : cparams.n_threads; |
2055 | 0 | ggml_threadpool_t tp = batched ? threadpool_batch : threadpool; |
2056 | |
|
2057 | 0 | if (backend_cpu != nullptr) { |
2058 | 0 | auto * reg = ggml_backend_dev_backend_reg(ggml_backend_get_device(backend_cpu)); |
2059 | 0 | auto * set_threadpool_fn = (decltype(ggml_backend_cpu_set_threadpool) *) ggml_backend_reg_get_proc_address(reg, "ggml_backend_cpu_set_threadpool"); |
2060 | 0 | if (set_threadpool_fn) { |
2061 | 0 | set_threadpool_fn(backend_cpu, tp); |
2062 | 0 | } |
2063 | 0 | } |
2064 | | |
2065 | | // set the number of threads for all the backends |
2066 | 0 | for (const auto & set_n_threads_fn : set_n_threads_fns) { |
2067 | 0 | set_n_threads_fn.second(set_n_threads_fn.first, n_threads); |
2068 | 0 | } |
2069 | |
|
2070 | 0 | auto status = ggml_backend_sched_graph_compute_async(sched.get(), gf); |
2071 | 0 | if (status != GGML_STATUS_SUCCESS) { |
2072 | 0 | LLAMA_LOG_ERROR("%s: ggml_backend_sched_graph_compute_async failed with error %d\n", __func__, status); |
2073 | 0 | } |
2074 | | |
2075 | | // fprintf(stderr, "splits: %d\n", ggml_backend_sched_get_n_splits(sched)); |
2076 | |
|
2077 | 0 | return status; |
2078 | 0 | } |
2079 | | |
2080 | 0 | llm_graph_cb llama_context::graph_get_cb() const { |
2081 | 0 | return [&](const llama_ubatch & ubatch, ggml_tensor * cur, const char * name, int il) { |
2082 | 0 | if (il >= 0) { |
2083 | 0 | ggml_format_name(cur, "%s-%d", name, il); |
2084 | 0 | } else { |
2085 | 0 | ggml_set_name(cur, name); |
2086 | 0 | } |
2087 | |
|
2088 | 0 | if (!cparams.offload_kqv) { |
2089 | 0 | if (strcmp(name, "kqv_merged_cont") == 0) { |
2090 | | // all nodes between the KV store and the attention output are run on the CPU |
2091 | 0 | ggml_backend_sched_set_tensor_backend(sched.get(), cur, backend_cpu); |
2092 | 0 | } |
2093 | 0 | } |
2094 | | |
2095 | | // norm may be automatically assigned to the backend of the previous layer, increasing data transfer between backends |
2096 | | // FIXME: fix in ggml_backend_sched |
2097 | 0 | const bool full_offload = model.n_gpu_layers() > model.hparams.n_layer; |
2098 | 0 | if (ubatch.n_tokens < 32 || full_offload) { |
2099 | 0 | if (il != -1 && strcmp(name, "norm") == 0) { |
2100 | 0 | const auto & dev_layer = model.dev_layer(il); |
2101 | 0 | for (const auto & backend : backends) { |
2102 | 0 | if (ggml_backend_get_device(backend.get()) == dev_layer) { |
2103 | 0 | if (ggml_backend_supports_op(backend.get(), cur)) { |
2104 | 0 | ggml_backend_sched_set_tensor_backend(sched.get(), cur, backend.get()); |
2105 | 0 | } |
2106 | 0 | } |
2107 | 0 | } |
2108 | 0 | } |
2109 | 0 | } |
2110 | 0 | }; |
2111 | 0 | } |
2112 | | |
2113 | | // |
2114 | | // state save/load |
2115 | | // |
2116 | | |
2117 | | class llama_io_write_dummy : public llama_io_write_i { |
2118 | | public: |
2119 | 0 | llama_io_write_dummy() = default; |
2120 | | |
2121 | 0 | void write(const void * /* src */, size_t size) override { |
2122 | 0 | size_written += size; |
2123 | 0 | } |
2124 | | |
2125 | 0 | void write_tensor(const ggml_tensor * /* tensor */, size_t /* offset */, size_t size) override { |
2126 | 0 | size_written += size; |
2127 | 0 | } |
2128 | | |
2129 | 0 | size_t n_bytes() override { |
2130 | 0 | return size_written; |
2131 | 0 | } |
2132 | | |
2133 | | private: |
2134 | | size_t size_written = 0; |
2135 | | }; |
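 | | // note: the dummy writer only accumulates a byte count - state_get_size() uses it to measure |
 | | // the serialized state without writing anything |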
2136 | | |
2137 | | class llama_io_write_buffer : public llama_io_write_i { |
2138 | | public: |
2139 | | llama_io_write_buffer( |
2140 | 0 | uint8_t * p, size_t len) : ptr(p), buf_size(len) {} |
2141 | | |
2142 | 0 | void write(const void * src, size_t size) override { |
2143 | 0 | if (size > buf_size) { |
2144 | 0 | throw std::runtime_error("unexpectedly reached end of buffer"); |
2145 | 0 | } |
2146 | 0 | memcpy(ptr, src, size); |
2147 | 0 | ptr += size; |
2148 | 0 | size_written += size; |
2149 | 0 | buf_size -= size; |
2150 | 0 | } |
2151 | | |
2152 | 0 | void write_tensor(const ggml_tensor * tensor, size_t offset, size_t size) override { |
2153 | 0 | if (size > buf_size) { |
2154 | 0 | throw std::runtime_error("unexpectedly reached end of buffer"); |
2155 | 0 | } |
2156 | 0 | ggml_backend_tensor_get(tensor, ptr, offset, size); |
2157 | 0 | ptr += size; |
2158 | 0 | size_written += size; |
2159 | 0 | buf_size -= size; |
2160 | 0 | } |
2161 | | |
2162 | 0 | size_t n_bytes() override { |
2163 | 0 | return size_written; |
2164 | 0 | } |
2165 | | |
2166 | | private: |
2167 | | uint8_t * ptr; |
2168 | | size_t buf_size = 0; |
2169 | | size_t size_written = 0; |
2170 | | }; |
2171 | | |
2172 | | class llama_io_read_buffer : public llama_io_read_i { |
2173 | | public: |
2174 | 0 | llama_io_read_buffer(const uint8_t * p, size_t len) : ptr(p), buf_size(len) {} |
2175 | | |
2176 | 0 | const uint8_t * read(size_t size) override { |
2177 | 0 | const uint8_t * base_ptr = ptr; |
2178 | 0 | if (size > buf_size) { |
2179 | 0 | throw std::runtime_error("unexpectedly reached end of buffer"); |
2180 | 0 | } |
2181 | 0 | ptr += size; |
2182 | 0 | size_read += size; |
2183 | 0 | buf_size -= size; |
2184 | 0 | return base_ptr; |
2185 | 0 | } |
2186 | | |
2187 | 0 | void read_to(void * dst, size_t size) override { |
2188 | 0 | memcpy(dst, read(size), size); |
2189 | 0 | } |
2190 | | |
2191 | 0 | size_t n_bytes() override { |
2192 | 0 | return size_read; |
2193 | 0 | } |
2194 | | |
2195 | | private: |
2196 | | const uint8_t * ptr; |
2197 | | size_t buf_size = 0; |
2198 | | size_t size_read = 0; |
2199 | | }; |
2200 | | |
2201 | | class llama_io_write_file : public llama_io_write_i { |
2202 | | public: |
2203 | 0 | llama_io_write_file(llama_file * f) : file(f) {} |
2204 | | |
2205 | 0 | void write(const void * src, size_t size) override { |
2206 | 0 | file->write_raw(src, size); |
2207 | 0 | size_written += size; |
2208 | 0 | } |
2209 | | |
2210 | 0 | void write_tensor(const ggml_tensor * tensor, size_t offset, size_t size) override { |
2211 | 0 | temp_buffer.resize(size); |
2212 | 0 | ggml_backend_tensor_get(tensor, temp_buffer.data(), offset, size); |
2213 | 0 | write(temp_buffer.data(), temp_buffer.size()); |
2214 | 0 | } |
2215 | | |
2216 | 0 | size_t n_bytes() override { |
2217 | 0 | return size_written; |
2218 | 0 | } |
2219 | | |
2220 | | private: |
2221 | | llama_file * file; |
2222 | | size_t size_written = 0; |
2223 | | std::vector<uint8_t> temp_buffer; |
2224 | | }; |
2225 | | |
2226 | | class llama_io_read_file : public llama_io_read_i { |
2227 | | public: |
2228 | 0 | llama_io_read_file(llama_file * f) : file(f) {} |
2229 | | |
2230 | 0 | void read_to(void * dst, size_t size) override { |
2231 | 0 | file->read_raw(dst, size); |
2232 | 0 | size_read += size; |
2233 | 0 | } |
2234 | | |
2235 | 0 | const uint8_t * read(size_t size) override { |
2236 | 0 | temp_buffer.resize(size); |
2237 | 0 | read_to(temp_buffer.data(), size); |
2238 | 0 | return temp_buffer.data(); |
2239 | 0 | } |
2240 | | |
2241 | 0 | size_t n_bytes() override { |
2242 | 0 | return size_read; |
2243 | 0 | } |
2244 | | |
2245 | | private: |
2246 | | llama_file * file; |
2247 | | size_t size_read = 0; |
2248 | | std::vector<uint8_t> temp_buffer; |
2249 | | }; |
2250 | | |
2251 | 0 | size_t llama_context::state_get_size() { |
2252 | 0 | llama_io_write_dummy io; |
2253 | 0 | try { |
2254 | 0 | return state_write_data(io); |
2255 | 0 | } catch (const std::exception & err) { |
2256 | 0 | LLAMA_LOG_ERROR("%s: error getting state size: %s\n", __func__, err.what()); |
2257 | 0 | return 0; |
2258 | 0 | } |
2259 | 0 | } |
2260 | | |
2261 | 0 | size_t llama_context::state_get_data(uint8_t * dst, size_t size) { |
2262 | 0 | llama_io_write_buffer io(dst, size); |
2263 | 0 | try { |
2264 | 0 | return state_write_data(io); |
2265 | 0 | } catch (const std::exception & err) { |
2266 | 0 | LLAMA_LOG_ERROR("%s: error saving state: %s\n", __func__, err.what()); |
2267 | 0 | return 0; |
2268 | 0 | } |
2269 | 0 | } |
2270 | | |
2271 | 0 | size_t llama_context::state_set_data(const uint8_t * src, size_t size) { |
2272 | 0 | llama_io_read_buffer io(src, size); |
2273 | 0 | try { |
2274 | 0 | return state_read_data(io); |
2275 | 0 | } catch (const std::exception & err) { |
2276 | 0 | LLAMA_LOG_ERROR("%s: error loading state: %s\n", __func__, err.what()); |
2277 | 0 | return 0; |
2278 | 0 | } |
2279 | 0 | } |
2280 | | |
2281 | 0 | size_t llama_context::state_seq_get_size(llama_seq_id seq_id, llama_state_seq_flags flags) { |
2282 | 0 | llama_io_write_dummy io; |
2283 | 0 | try { |
2284 | 0 | return state_seq_write_data(io, seq_id, flags); |
2285 | 0 | } catch (const std::exception & err) { |
2286 | 0 | LLAMA_LOG_ERROR("%s: error getting state size: %s\n", __func__, err.what()); |
2287 | 0 | return 0; |
2288 | 0 | } |
2289 | 0 | } |
2290 | | |
2291 | 0 | size_t llama_context::state_seq_get_data(llama_seq_id seq_id, uint8_t * dst, size_t size, llama_state_seq_flags flags) { |
2292 | 0 | llama_io_write_buffer io(dst, size); |
2293 | 0 | try { |
2294 | 0 | return state_seq_write_data(io, seq_id, flags); |
2295 | 0 | } catch (const std::exception & err) { |
2296 | 0 | LLAMA_LOG_ERROR("%s: error saving state: %s\n", __func__, err.what()); |
2297 | 0 | return 0; |
2298 | 0 | } |
2299 | 0 | } |
2300 | | |
2301 | 0 | size_t llama_context::state_seq_set_data(llama_seq_id seq_id, const uint8_t * src, size_t size, llama_state_seq_flags flags) { |
2302 | 0 | llama_io_read_buffer io(src, size); |
2303 | 0 | try { |
2304 | 0 | return state_seq_read_data(io, seq_id, flags); |
2305 | 0 | } catch (const std::exception & err) { |
2306 | 0 | LLAMA_LOG_ERROR("%s: error loading state: %s\n", __func__, err.what()); |
2307 | 0 | return 0; |
2308 | 0 | } |
2309 | 0 | } |
2310 | | |
2311 | 0 | bool llama_context::state_load_file(const char * filepath, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) { |
2312 | 0 | llama_file file(filepath, "rb"); |
2313 | | |
2314 | | // sanity checks |
2315 | 0 | { |
2316 | 0 | const uint32_t magic = file.read_u32(); |
2317 | 0 | const uint32_t version = file.read_u32(); |
2318 | |
|
2319 | 0 | if (magic != LLAMA_SESSION_MAGIC || version != LLAMA_SESSION_VERSION) { |
2320 | 0 | LLAMA_LOG_ERROR("%s: unknown (magic, version) for session file: %08x, %08x\n", __func__, magic, version); |
2321 | 0 | return false; |
2322 | 0 | } |
2323 | 0 | } |
2324 | | |
2325 | | // load the prompt |
2326 | 0 | { |
2327 | 0 | const uint32_t n_token_count = file.read_u32(); |
2328 | |
|
2329 | 0 | if (n_token_count > n_token_capacity) { |
2330 | 0 | LLAMA_LOG_ERROR("%s: token count in session file exceeded capacity! %u > %zu\n", __func__, n_token_count, n_token_capacity); |
2331 | 0 | return false; |
2332 | 0 | } |
2333 | | |
2334 | 0 | file.read_raw(tokens_out, sizeof(llama_token) * n_token_count); |
2335 | 0 | *n_token_count_out = n_token_count; |
2336 | 0 | } |
2337 | | |
2338 | | // restore the context state |
2339 | 0 | { |
2340 | 0 | const size_t n_state_size_cur = file.size() - file.tell(); |
2341 | |
|
2342 | 0 | llama_io_read_file io(&file); |
2343 | 0 | const size_t n_read = state_read_data(io); |
2344 | |
|
2345 | 0 | if (n_read != n_state_size_cur) { |
2346 | 0 | LLAMA_LOG_ERROR("%s: did not read all of the session file data! size %zu, got %zu\n", __func__, n_state_size_cur, n_read); |
2347 | 0 | return false; |
2348 | 0 | } |
2349 | 0 | } |
2350 | | |
2351 | 0 | return true; |
2352 | 0 | } |
2353 | | |
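 | | // note: the session file written below is laid out as: magic (u32), version (u32), |
 | | // n_token_count (u32), the prompt tokens, then the serialized context state |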
2354 | 0 | bool llama_context::state_save_file(const char * filepath, const llama_token * tokens, size_t n_token_count) { |
2355 | 0 | llama_file file(filepath, "wb"); |
2356 | |
|
2357 | 0 | file.write_u32(LLAMA_SESSION_MAGIC); |
2358 | 0 | file.write_u32(LLAMA_SESSION_VERSION); |
2359 | | |
2360 | | // save the prompt |
2361 | 0 | file.write_u32((uint32_t) n_token_count); |
2362 | 0 | file.write_raw(tokens, sizeof(llama_token) * n_token_count); |
2363 | | |
2364 | | // save the context state using stream saving |
2365 | 0 | llama_io_write_file io(&file); |
2366 | 0 | state_write_data(io); |
2367 | |
|
2368 | 0 | return true; |
2369 | 0 | } |
2370 | | |
2371 | 0 | size_t llama_context::state_seq_load_file(llama_seq_id seq_id, const char * filepath, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) { |
2372 | 0 | llama_file file(filepath, "rb"); |
2373 | | |
2374 | | // version checks |
2375 | 0 | { |
2376 | 0 | const uint32_t magic = file.read_u32(); |
2377 | 0 | const uint32_t version = file.read_u32(); |
2378 | |
|
2379 | 0 | if (magic != LLAMA_STATE_SEQ_MAGIC || version != LLAMA_STATE_SEQ_VERSION) { |
2380 | 0 | LLAMA_LOG_ERROR("%s: unknown (magic, version) for sequence state file: %08x, %08x\n", __func__, magic, version); |
2381 | 0 | return 0; |
2382 | 0 | } |
2383 | 0 | } |
2384 | | |
2385 | | // load the prompt |
2386 | 0 | { |
2387 | 0 | const uint32_t n_token_count = file.read_u32(); |
2388 | |
|
2389 | 0 | if (n_token_count > n_token_capacity) { |
2390 | 0 | LLAMA_LOG_ERROR("%s: token count in sequence state file exceeded capacity! %u > %zu\n", __func__, n_token_count, n_token_capacity); |
2391 | 0 | return 0; |
2392 | 0 | } |
2393 | | |
2394 | 0 | file.read_raw(tokens_out, sizeof(llama_token) * n_token_count); |
2395 | 0 | *n_token_count_out = n_token_count; |
2396 | 0 | } |
2397 | | |
2398 | | // restore the context state |
2399 | 0 | { |
2400 | 0 | const size_t state_size = file.size() - file.tell(); |
2401 | 0 | llama_io_read_file io(&file); |
2402 | 0 | const size_t nread = state_seq_read_data(io, seq_id, 0); |
2403 | 0 | if (!nread) { |
2404 | 0 | LLAMA_LOG_ERROR("%s: failed to restore sequence state\n", __func__); |
2405 | 0 | return 0; |
2406 | 0 | } |
2407 | 0 | GGML_ASSERT(nread <= state_size); |
2408 | 0 | GGML_ASSERT(nread + sizeof(uint32_t) * 3 + sizeof(llama_token) * *n_token_count_out == file.tell()); |
2409 | 0 | } |
2410 | | |
2411 | 0 | return file.tell(); |
2412 | 0 | } |
2413 | | |
2414 | 0 | size_t llama_context::state_seq_save_file(llama_seq_id seq_id, const char * filepath, const llama_token * tokens, size_t n_token_count) { |
2415 | 0 | llama_file file(filepath, "wb"); |
2416 | |
|
2417 | 0 | file.write_u32(LLAMA_STATE_SEQ_MAGIC); |
2418 | 0 | file.write_u32(LLAMA_STATE_SEQ_VERSION); |
2419 | | |
2420 | | // save the prompt |
2421 | 0 | file.write_u32((uint32_t) n_token_count); |
2422 | 0 | file.write_raw(tokens, sizeof(llama_token) * n_token_count); |
2423 | | |
2424 | | // save the context state using stream saving |
2425 | 0 | llama_io_write_file io(&file); |
2426 | 0 | state_seq_write_data(io, seq_id, 0); |
2427 | |
|
2428 | 0 | const size_t res = file.tell(); |
2429 | 0 | GGML_ASSERT(res == sizeof(uint32_t) * 3 + sizeof(llama_token) * n_token_count + io.n_bytes()); |
2430 | |
|
2431 | 0 | return res; |
2432 | 0 | } |
2433 | | |
2434 | 0 | size_t llama_context::state_write_data(llama_io_write_i & io) { |
2435 | 0 | LLAMA_LOG_DEBUG("%s: writing state\n", __func__); |
2436 | | |
2437 | | // write model info |
2438 | 0 | { |
2439 | 0 | LLAMA_LOG_DEBUG("%s: - writing model info\n", __func__); |
2440 | |
|
2441 | 0 | const std::string arch_str = llm_arch_name(model.arch); |
2442 | 0 | io.write_string(arch_str); |
2443 | | // TODO: add more model-specific info which should prevent loading the session file if not identical |
2444 | 0 | } |
2445 | | |
2446 | | // write output ids |
2447 | 0 | { |
2448 | 0 | LLAMA_LOG_DEBUG("%s: - writing output ids\n", __func__); |
2449 | |
|
2450 | 0 | const auto n_outputs = this->n_outputs; |
2451 | 0 | const auto & output_ids = this->output_ids; |
2452 | |
|
2453 | 0 | std::vector<int32_t> w_output_pos; |
2454 | |
|
2455 | 0 | w_output_pos.resize(n_outputs); |
2456 | | |
2457 | | // build a more compact representation of the output ids |
2458 | 0 | for (size_t i = 0; i < n_batch(); ++i) { |
2459 | | // map an output id to a position in the batch |
2460 | 0 | int64_t pos = output_ids[i]; |
2461 | 0 | if (pos >= 0) { |
2462 | 0 | GGML_ASSERT(pos < n_outputs); |
2463 | 0 | w_output_pos[pos] = i; |
2464 | 0 | } |
2465 | 0 | } |
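 | | // e.g. with n_outputs = 2 and output_ids = { -1, 0, -1, 1, -1, ... } this yields |
 | | // w_output_pos = { 1, 3 }, i.e. output row 0 came from batch position 1 and row 1 from position 3 |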
2466 | |
|
2467 | 0 | io.write(&n_outputs, sizeof(n_outputs)); |
2468 | |
|
2469 | 0 | if (n_outputs) { |
2470 | 0 | io.write(w_output_pos.data(), n_outputs * sizeof(int32_t)); |
2471 | 0 | } |
2472 | 0 | } |
2473 | | |
2474 | | // write logits |
2475 | 0 | { |
2476 | 0 | LLAMA_LOG_DEBUG("%s: - writing logits\n", __func__); |
2477 | |
|
2478 | 0 | const uint64_t logits_size = std::min((uint64_t) this->logits_size, (uint64_t) n_outputs * model.vocab.n_tokens()); |
2479 | |
|
2480 | 0 | io.write(&logits_size, sizeof(logits_size)); |
2481 | |
|
2482 | 0 | if (logits_size) { |
2483 | 0 | io.write(logits, logits_size * sizeof(float)); |
2484 | 0 | } |
2485 | 0 | } |
2486 | | |
2487 | | // write embeddings |
2488 | 0 | { |
2489 | 0 | LLAMA_LOG_DEBUG("%s: - writing embeddings\n", __func__); |
2490 | |
|
2491 | 0 | const uint64_t embd_size = std::min((uint64_t) this->embd_size, (uint64_t) n_outputs * model.hparams.n_embd); |
2492 | |
|
2493 | 0 | io.write(&embd_size, sizeof(embd_size)); |
2494 | |
|
2495 | 0 | if (embd_size) { |
2496 | 0 | io.write(embd, embd_size * sizeof(float)); |
2497 | 0 | } |
2498 | 0 | } |
2499 | | |
2500 | | // TODO: handle sampling buffers and samplers state ? |
2501 | | // https://github.com/ggml-org/llama.cpp/pull/17004 |
2502 | |
|
2503 | 0 | if (memory != nullptr) { |
2504 | 0 | LLAMA_LOG_DEBUG("%s: - writing memory module\n", __func__); |
2505 | 0 | memory->state_write(io); |
2506 | 0 | } |
2507 | |
|
2508 | 0 | return io.n_bytes(); |
2509 | 0 | } |
2510 | | |
2511 | 0 | size_t llama_context::state_read_data(llama_io_read_i & io) { |
2512 | 0 | LLAMA_LOG_DEBUG("%s: reading state\n", __func__); |
2513 | | |
2514 | | // read model info |
2515 | 0 | { |
2516 | 0 | LLAMA_LOG_DEBUG("%s: - reading model info\n", __func__); |
2517 | |
|
2518 | 0 | const std::string cur_arch_str = llm_arch_name(model.arch); |
2519 | |
|
2520 | 0 | std::string arch_str; |
2521 | 0 | io.read_string(arch_str); |
2522 | 0 | if (cur_arch_str != arch_str) { |
2523 | 0 | throw std::runtime_error(format("wrong model arch: '%s' instead of '%s'", arch_str.c_str(), cur_arch_str.c_str())); |
2524 | 0 | } |
2525 | | // TODO: add more info which needs to be identical but which is not verified otherwise |
2526 | 0 | } |
2527 | | |
2528 | | // read output ids |
2529 | 0 | { |
2530 | 0 | LLAMA_LOG_DEBUG("%s: - reading output ids\n", __func__); |
2531 | |
|
2532 | 0 | auto n_outputs = this->n_outputs; |
2533 | 0 | io.read_to(&n_outputs, sizeof(n_outputs)); |
2534 | | |
2535 | | // Create a dummy batch for state loading. |
2536 | 0 | llama_batch dummy_batch = {}; |
2537 | 0 | dummy_batch.n_tokens = 0; |
2538 | 0 | if (n_outputs > output_reserve(n_outputs, dummy_batch)) { |
2539 | 0 | throw std::runtime_error("could not reserve outputs"); |
2540 | 0 | } |
2541 | | |
2542 | 0 | std::vector<int32_t> output_pos; |
2543 | |
|
2544 | 0 | if (n_outputs) { |
2545 | 0 | output_pos.resize(n_outputs); |
2546 | 0 | io.read_to(output_pos.data(), n_outputs * sizeof(int32_t)); |
2547 | |
|
2548 | 0 | for (int32_t i = 0; i < (int32_t) output_pos.size(); ++i) { |
2549 | 0 | int32_t id = output_pos[i]; |
2550 | 0 | if ((uint32_t) id >= n_batch()) { |
2551 | 0 | throw std::runtime_error(format("invalid output id, %d does not fit in batch size of %u", id, n_batch())); |
2552 | 0 | } |
2553 | 0 | this->output_ids[id] = i; |
2554 | 0 | } |
2555 | | |
2556 | 0 | this->n_outputs = n_outputs; |
2557 | 0 | } |
2558 | 0 | } |
2559 | | |
2560 | | // read logits |
2561 | 0 | { |
2562 | 0 | LLAMA_LOG_DEBUG("%s: - reading logits\n", __func__); |
2563 | |
|
2564 | 0 | uint64_t logits_size; |
2565 | 0 | io.read_to(&logits_size, sizeof(logits_size)); |
2566 | |
|
2567 | 0 | if (this->logits_size < logits_size) { |
2568 | 0 | throw std::runtime_error("logits buffer too small"); |
2569 | 0 | } |
2570 | | |
2571 | 0 | if (logits_size) { |
2572 | 0 | io.read_to(this->logits, logits_size * sizeof(float)); |
2573 | 0 | } |
2574 | 0 | } |
2575 | | |
2576 | | // read embeddings |
2577 | 0 | { |
2578 | 0 | LLAMA_LOG_DEBUG("%s: - reading embeddings\n", __func__); |
2579 | |
|
2580 | 0 | uint64_t embd_size; |
2581 | 0 | io.read_to(&embd_size, sizeof(embd_size)); |
2582 | |
|
2583 | 0 | if (this->embd_size < embd_size) { |
2584 | 0 | throw std::runtime_error("embeddings buffer too small"); |
2585 | 0 | } |
2586 | | |
2587 | 0 | if (embd_size) { |
2588 | 0 | io.read_to(this->embd, embd_size * sizeof(float)); |
2589 | 0 | } |
2590 | 0 | } |
2591 | | |
2592 | | // TODO: handle sampling buffers and samplers state ? |
2593 | | // https://github.com/ggml-org/llama.cpp/pull/17004 |
2594 | | |
2595 | 0 | if (memory) { |
2596 | 0 | LLAMA_LOG_DEBUG("%s: - reading memory module\n", __func__); |
2597 | |
|
2598 | 0 | memory->state_read(io); |
2599 | 0 | } |
2600 | |
|
2601 | 0 | return io.n_bytes(); |
2602 | 0 | } |
2603 | | |
2604 | 0 | size_t llama_context::state_seq_write_data(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) { |
2605 | 0 | GGML_UNUSED(seq_id); |
2606 | |
|
2607 | 0 | if (memory) { |
2608 | 0 | memory->state_write(io, seq_id, flags); |
2609 | 0 | } |
2610 | |
|
2611 | 0 | return io.n_bytes(); |
2612 | 0 | } |
2613 | | |
2614 | 0 | size_t llama_context::state_seq_read_data(llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) { |
2615 | 0 | GGML_UNUSED(seq_id); |
2616 | |
|
2617 | 0 | if (memory) { |
2618 | 0 | memory->state_read(io, seq_id, flags); |
2619 | 0 | } |
2620 | |
|
2621 | 0 | return io.n_bytes(); |
2622 | 0 | } |
2623 | | |
2624 | | // |
2625 | | // perf |
2626 | | // |
2627 | | |
2628 | 0 | llama_perf_context_data llama_context::perf_get_data() const { |
2629 | 0 | llama_perf_context_data data = {}; |
2630 | |
|
2631 | 0 | data.t_start_ms = 1e-3 * t_start_us; |
2632 | 0 | data.t_load_ms = 1e-3 * t_load_us; |
2633 | 0 | data.t_p_eval_ms = 1e-3 * t_p_eval_us; |
2634 | 0 | data.t_eval_ms = 1e-3 * t_eval_us; |
2635 | 0 | data.n_p_eval = std::max(1, n_p_eval); |
2636 | 0 | data.n_eval = std::max(1, n_eval); |
2637 | 0 | data.n_reused = std::max(0, n_reused); |
2638 | |
|
2639 | 0 | return data; |
2640 | 0 | } |
2641 | | |
2642 | 0 | void llama_context::perf_reset() { |
2643 | 0 | t_start_us = ggml_time_us(); |
2644 | 0 | t_eval_us = n_eval = 0; |
2645 | 0 | t_p_eval_us = n_p_eval = 0; |
2646 | 0 | n_reused = 0; |
2647 | 0 | } |
2648 | | |
2649 | 0 | std::map<ggml_backend_buffer_type_t, llama_memory_breakdown_data> llama_context::memory_breakdown() const { |
2650 | 0 | std::map<ggml_backend_buffer_type_t, llama_memory_breakdown_data> ret; |
2651 | 0 | for (const auto & [buft, size] : model.memory_breakdown()) { |
2652 | 0 | ret[buft].model += size; |
2653 | 0 | } |
2654 | 0 | if (memory) { |
2655 | 0 | for (const auto & [buft, size] : memory->memory_breakdown()) { |
2656 | 0 | ret[buft].context += size; |
2657 | 0 | } |
2658 | 0 | } |
2659 | 0 | if (model.hparams.no_alloc) { |
2660 | 0 | for (size_t i = 0; i < backends.size(); ++i) { |
2661 | 0 | ggml_backend_t backend = backends[i].get(); |
2662 | 0 | ggml_backend_buffer_type_t buft = ggml_backend_sched_get_buffer_type(sched.get(), backend); |
2663 | 0 | ret[buft].compute += backend_buf_exp_size[i]; |
2664 | 0 | } |
2665 | 0 | } else { |
2666 | 0 | for (const auto & backend_ptr : backends) { |
2667 | 0 | ggml_backend_t backend = backend_ptr.get(); |
2668 | 0 | ggml_backend_buffer_type_t buft = ggml_backend_sched_get_buffer_type(sched.get(), backend); |
2669 | 0 | ret[buft].compute += ggml_backend_sched_get_buffer_size(sched.get(), backend); |
2670 | 0 | } |
2671 | 0 | } |
2672 | 0 | return ret; |
2673 | 0 | } |
2674 | | |
2675 | | // |
2676 | | // training |
2677 | | // |
2678 | | |
2679 | 0 | static void llama_set_param(struct ggml_tensor * tensor, llama_opt_param_filter param_filter, void * userdata) { |
2680 | 0 | if (!tensor || tensor->type != GGML_TYPE_F32) { |
2681 | 0 | return; |
2682 | 0 | } |
2683 | 0 | if (!param_filter(tensor, userdata)) { |
2684 | 0 | return; |
2685 | 0 | } |
2686 | 0 | if (strcmp(tensor->name, "token_embd.weight") == 0) { |
2687 | 0 | return; // FIXME |
2688 | 0 | } |
2689 | 0 | if (strcmp(tensor->name, "rope_freqs.weight") == 0) { |
2690 | 0 | return; // FIXME |
2691 | 0 | } |
2692 | 0 | ggml_set_param(tensor); |
2693 | 0 | } |
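 | | // note: only F32 tensors accepted by the user-provided filter are marked as trainable; |
 | | // token_embd and rope_freqs are currently skipped (see the FIXMEs above) |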
2694 | | |
2695 | 0 | void llama_context::opt_init(struct llama_model * model, struct llama_opt_params lopt_params) { |
2696 | 0 | GGML_ASSERT(!opt_ctx); |
2697 | 0 | model->hparams.n_ctx_train = lopt_params.n_ctx_train > 0 ? lopt_params.n_ctx_train : n_ctx(); |
2698 | 0 | const uint32_t n_batch = std::min(this->n_batch(), model->hparams.n_ctx_train); |
2699 | 0 | const uint32_t n_ubatch = std::min(this->n_ubatch(), n_batch); |
2700 | 0 | GGML_ASSERT(model->hparams.n_ctx_train % n_batch == 0); |
2701 | 0 | GGML_ASSERT(n_batch % n_ubatch == 0); |
2702 | |
|
2703 | 0 | ggml_opt_params opt_params = ggml_opt_default_params(sched.get(), GGML_OPT_LOSS_TYPE_CROSS_ENTROPY); |
2704 | 0 | opt_params.opt_period = n_batch / n_ubatch; |
2705 | 0 | opt_params.get_opt_pars = lopt_params.get_opt_pars; |
2706 | 0 | opt_params.get_opt_pars_ud = lopt_params.get_opt_pars_ud; |
2707 | 0 | opt_params.optimizer = lopt_params.optimizer_type; |
2708 | 0 | opt_ctx = ggml_opt_init(opt_params); |
2709 | |
|
2710 | 0 | llama_opt_param_filter param_filter = lopt_params.param_filter; |
2711 | 0 | void * param_filter_ud = lopt_params.param_filter_ud; |
2712 | | |
2713 | | //llama_set_param(model->tok_embd, param_filter, param_filter_ud); // FIXME |
2714 | 0 | llama_set_param(model->type_embd, param_filter, param_filter_ud); |
2715 | 0 | llama_set_param(model->pos_embd, param_filter, param_filter_ud); |
2716 | 0 | llama_set_param(model->tok_norm, param_filter, param_filter_ud); |
2717 | 0 | llama_set_param(model->tok_norm_b, param_filter, param_filter_ud); |
2718 | 0 | llama_set_param(model->output_norm, param_filter, param_filter_ud); |
2719 | 0 | llama_set_param(model->output_norm_b, param_filter, param_filter_ud); |
2720 | 0 | llama_set_param(model->output, param_filter, param_filter_ud); |
2721 | 0 | llama_set_param(model->output_b, param_filter, param_filter_ud); |
2722 | 0 | llama_set_param(model->output_norm_enc, param_filter, param_filter_ud); |
2723 | 0 | llama_set_param(model->cls, param_filter, param_filter_ud); |
2724 | 0 | llama_set_param(model->cls_b, param_filter, param_filter_ud); |
2725 | 0 | llama_set_param(model->cls_out, param_filter, param_filter_ud); |
2726 | 0 | llama_set_param(model->cls_out_b, param_filter, param_filter_ud); |
2727 | |
|
2728 | 0 | for (struct llama_layer & layer : model->layers) { |
2729 | 0 | for (size_t i = 0; i < sizeof(layer)/sizeof(struct ggml_tensor *); ++i) { |
2730 | 0 | llama_set_param(reinterpret_cast<struct ggml_tensor **>(&layer)[i], param_filter, param_filter_ud); |
2731 | 0 | } |
2732 | 0 | } |
2733 | 0 | } |
2734 | | |
2735 | | void llama_context::opt_epoch_iter( |
2736 | | ggml_opt_dataset_t dataset, |
2737 | | ggml_opt_result_t result, |
2738 | | const std::vector<llama_token> & tokens, |
2739 | | const std::vector<llama_token> & labels_sparse, |
2740 | | llama_batch & batch, |
2741 | | ggml_opt_epoch_callback callback, |
2742 | | bool train, |
2743 | | int64_t idata_in_loop, |
2744 | | int64_t ndata_in_loop, |
2745 | 0 | int64_t t_loop_start) { |
2746 | 0 | GGML_ASSERT(opt_ctx); |
2747 | 0 | const uint32_t n_ctx = llama_model_n_ctx_train(&model); |
2748 | 0 | const uint32_t n_batch = std::min(this->n_batch(), n_ctx); |
2749 | 0 | const uint32_t n_ubatch = std::min(this->n_ubatch(), n_batch); |
2750 | |
|
2751 | 0 | memory->clear(true); |
2752 | |
|
2753 | 0 | for (uint32_t pos_ctx = 0; pos_ctx < n_ctx; pos_ctx += n_batch) { |
2754 | 0 | batch.n_tokens = n_batch; |
2755 | 0 | for (uint32_t pos_batch = 0; pos_batch < n_batch; ++pos_batch) { |
2756 | 0 | batch.token [pos_batch] = tokens[pos_ctx + pos_batch]; |
2757 | 0 | batch.pos [pos_batch] = pos_ctx + pos_batch; |
2758 | 0 | batch.n_seq_id[pos_batch] = 1; |
2759 | 0 | batch.seq_id [pos_batch][0] = 0; |
2760 | 0 | batch.logits [pos_batch] = true; |
2761 | 0 | } |
2762 | |
|
2763 | 0 | if (!balloc->init(batch, model.vocab, nullptr, model.hparams.n_embd_inp(), cparams.kv_unified ? LLAMA_MAX_SEQ : cparams.n_seq_max, true)) { |
2764 | 0 | LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__); |
2765 | 0 | return; |
2766 | 0 | } |
2767 | | |
2768 | 0 | const uint32_t n_tokens_all = balloc->get_n_tokens(); |
2769 | |
|
2770 | 0 | n_queued_tokens += n_tokens_all; |
2771 | |
|
2772 | 0 | embd_seq.clear(); |
2773 | |
|
2774 | 0 | uint32_t n_outputs_all = n_tokens_all; |
2775 | |
|
2776 | 0 | auto mctx = memory->init_batch(*balloc, cparams.n_ubatch, true); |
2777 | 0 | if (!mctx || mctx->get_status() != LLAMA_MEMORY_STATUS_SUCCESS) { |
2778 | 0 | LLAMA_LOG_ERROR("%s: could not initialize batch\n", __func__); |
2779 | 0 | break; |
2780 | 0 | } |
2781 | | |
2782 | | // reserve output buffer |
2783 | 0 | if (output_reserve(n_outputs_all, balloc->get_batch()) < n_outputs_all) { |
2784 | 0 | LLAMA_LOG_ERROR("%s: could not reserve space for batch with %u outputs\n", __func__, n_outputs_all); |
2785 | 0 | GGML_ABORT("TODO: handle this error"); |
2786 | 0 | } |
2787 | |
|
2788 | 0 | uint32_t pos_batch = 0; |
2789 | 0 | do { |
2790 | 0 | const auto & ubatch = mctx->get_ubatch(); |
2791 | |
|
2792 | 0 | n_outputs = ubatch.n_tokens; |
2793 | |
|
2794 | 0 | if (!mctx->apply()) { |
2795 | 0 | LLAMA_LOG_ERROR("%s: failed to update the memory context\n", __func__); |
2796 | 0 | break; |
2797 | 0 | } |
2798 | | |
2799 | 0 | auto * res = gf_res_prev.get(); |
2800 | |
|
2801 | 0 | const auto gparams = graph_params(res, ubatch, mctx.get(), LLM_GRAPH_TYPE_DEFAULT); |
2802 | |
|
2803 | 0 | res->reset(); |
2804 | |
|
2805 | 0 | auto * gf = model.build_graph(gparams); |
2806 | |
|
2807 | 0 | struct ggml_context * ctx_compute_opt; |
2808 | 0 | { |
2809 | 0 | const size_t size_gf = ggml_graph_size(gf); |
2810 | 0 | const size_t size_meta = 4*size_gf*ggml_tensor_overhead() + 2*ggml_graph_overhead_custom(size_gf, /*grads = */ true); |
2811 | 0 | struct ggml_init_params params = { |
2812 | 0 | /*.mem_size =*/ size_meta, |
2813 | 0 | /*.mem_buffer =*/ nullptr, |
2814 | 0 | /*.no_alloc =*/ true, |
2815 | 0 | }; |
2816 | 0 | ctx_compute_opt = ggml_init(params); |
2817 | 0 | } |
2818 | 0 | ggml_opt_prepare_alloc(opt_ctx, ctx_compute_opt, gf, res->get_tokens(), res->get_logits()); |
2819 | 0 | ggml_opt_alloc(opt_ctx, train); |
2820 | |
|
2821 | 0 | res->set_inputs(&ubatch); |
2822 | 0 | { |
2823 | 0 | struct ggml_tensor * labels = ggml_opt_labels(opt_ctx); |
2824 | 0 | GGML_ASSERT(labels->ne[1] == n_ubatch); |
2825 | 0 | ggml_set_zero(labels); |
2826 | 0 | const float onef = 1.0f; |
2827 | 0 | for (uint32_t pos_ubatch = 0; pos_ubatch < n_ubatch; ++pos_ubatch) { |
2828 | 0 | const uint32_t ilabel = pos_ctx + pos_batch + pos_ubatch; |
2829 | 0 | GGML_ASSERT(labels_sparse[ilabel] < labels->ne[0]); |
2830 | 0 | ggml_backend_tensor_set(labels, &onef, (pos_ubatch*labels->ne[0] + labels_sparse[ilabel])*sizeof(float), sizeof(float)); |
2831 | 0 | } |
2832 | 0 | } |
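| | // Editor note: the block above encodes the training targets as a one-hot matrix: |
| | // ggml_opt_labels() is a [ne[0] = vocab size, ne[1] = n_ubatch] tensor that is zeroed |
| | // and then set to 1.0f at column labels_sparse[ilabel] for each ubatch position, |
| | // i.e. the expected next token for that position. |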
2833 | 0 | ggml_opt_eval(opt_ctx, result); |
2834 | 0 | if (callback) { |
2835 | 0 | callback(train, opt_ctx, dataset, result, idata_in_loop + (pos_ctx + pos_batch)/n_ubatch + 1, ndata_in_loop, t_loop_start); |
2836 | 0 | } |
2837 | 0 | ggml_free(ctx_compute_opt); |
2838 | |
|
2839 | 0 | pos_batch += ubatch.n_tokens; |
2840 | 0 | } while (mctx->next()); |
2841 | 0 | } |
2842 | 0 | } |
2843 | | |
2844 | | void llama_context::opt_epoch( |
2845 | | ggml_opt_dataset_t dataset, |
2846 | | ggml_opt_result_t result_train, |
2847 | | ggml_opt_result_t result_eval, |
2848 | | int64_t idata_split, |
2849 | | ggml_opt_epoch_callback callback_train, |
2850 | 0 | ggml_opt_epoch_callback callback_eval) { |
2851 | 0 | const uint32_t n_ctx = this->n_ctx(); |
2852 | 0 | const uint32_t n_batch = std::min(cparams.n_batch, n_ctx); |
2853 | 0 | const uint32_t n_ubatch = std::min(cparams.n_ubatch, n_batch); |
2854 | 0 | const int64_t ndata = ggml_opt_dataset_ndata(dataset); |
2855 | |
|
2856 | 0 | GGML_ASSERT(idata_split >= 0); |
2857 | 0 | GGML_ASSERT(idata_split <= ndata); |
2858 | |
|
2859 | 0 | const uint32_t ubatch_per_ctx = n_ctx / n_ubatch; |
2860 | |
|
2861 | 0 | struct llama_batch batch = llama_batch_init(n_batch, 0, 1); |
2862 | 0 | std::vector<llama_token> tokens(n_ctx); |
2863 | 0 | std::vector<llama_token> labels_sparse(n_ctx); |
2864 | |
|
2865 | 0 | int64_t idata = 0; |
2866 | |
|
2867 | 0 | int64_t t_loop_start = ggml_time_us(); |
2868 | 0 | int64_t ndata_in_loop = idata_split*ubatch_per_ctx; |
2869 | 0 | for (; idata < idata_split; ++idata) { |
2870 | 0 | constexpr bool train = true; |
2871 | 0 | const int64_t idata_in_loop = idata*ubatch_per_ctx; |
2872 | |
|
2873 | 0 | ggml_opt_dataset_get_batch_host(dataset, tokens.data(), n_ctx*sizeof(llama_token), labels_sparse.data(), idata); |
2874 | 0 | opt_epoch_iter(dataset, result_train, tokens, labels_sparse, batch, |
2875 | 0 | callback_train, train, idata_in_loop, ndata_in_loop, t_loop_start); |
2876 | 0 | } |
2877 | |
|
2878 | 0 | t_loop_start = ggml_time_us(); |
2879 | 0 | ndata_in_loop = (ndata - idata_split)*ubatch_per_ctx; |
2880 | 0 | for (; idata < ndata; ++idata) { |
2881 | 0 | constexpr bool train = false; |
2882 | 0 | const int64_t idata_in_loop = (idata - idata_split)*ubatch_per_ctx; |
2883 | |
|
2884 | 0 | ggml_opt_dataset_get_batch_host(dataset, tokens.data(), n_ctx*sizeof(llama_token), labels_sparse.data(), idata); |
2885 | 0 | opt_epoch_iter(dataset, result_eval, tokens, labels_sparse, batch, |
2886 | 0 | callback_eval, train, idata_in_loop, ndata_in_loop, t_loop_start); |
2887 | 0 | } |
2888 | |
|
2889 | 0 | llama_batch_free(batch); |
2890 | 0 | } |
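| | // Editor note: dataset entries [0, idata_split) are processed with training enabled, |
| | // entries [idata_split, ndata) in evaluation mode; each entry supplies one full |
| | // window of n_ctx tokens, and progress is reported per ubatch via the callbacks. |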
2891 | | |
2892 | | // |
2893 | | // interface implementation |
2894 | | // |
2895 | | |
2896 | 0 | llama_context_params llama_context_default_params() { |
2897 | 0 | llama_context_params result = { |
2898 | 0 | /*.n_ctx =*/ 512, |
2899 | 0 | /*.n_batch =*/ 2048, |
2900 | 0 | /*.n_ubatch =*/ 512, |
2901 | 0 | /*.n_seq_max =*/ 1, |
2902 | 0 | /*.n_threads =*/ GGML_DEFAULT_N_THREADS, // TODO: better default |
2903 | 0 | /*.n_threads_batch =*/ GGML_DEFAULT_N_THREADS, |
2904 | 0 | /*.rope_scaling_type =*/ LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED, |
2905 | 0 | /*.pooling_type =*/ LLAMA_POOLING_TYPE_UNSPECIFIED, |
2906 | 0 | /*.attention_type =*/ LLAMA_ATTENTION_TYPE_UNSPECIFIED, |
2907 | 0 | /*.flash_attn_type =*/ LLAMA_FLASH_ATTN_TYPE_AUTO, |
2908 | 0 | /*.rope_freq_base =*/ 0.0f, |
2909 | 0 | /*.rope_freq_scale =*/ 0.0f, |
2910 | 0 | /*.yarn_ext_factor =*/ -1.0f, |
2911 | 0 | /*.yarn_attn_factor =*/ -1.0f, |
2912 | 0 | /*.yarn_beta_fast =*/ -1.0f, |
2913 | 0 | /*.yarn_beta_slow =*/ -1.0f, |
2914 | 0 | /*.yarn_orig_ctx =*/ 0, |
2915 | 0 | /*.defrag_thold =*/ -1.0f, |
2916 | 0 | /*.cb_eval =*/ nullptr, |
2917 | 0 | /*.cb_eval_user_data =*/ nullptr, |
2918 | 0 | /*.type_k =*/ GGML_TYPE_F16, |
2919 | 0 | /*.type_v =*/ GGML_TYPE_F16, |
2920 | 0 | /*.abort_callback =*/ nullptr, |
2921 | 0 | /*.abort_callback_data =*/ nullptr, |
2922 | 0 | /*.embeddings =*/ false, |
2923 | 0 | /*.offload_kqv =*/ true, |
2924 | 0 | /*.no_perf =*/ true, |
2925 | 0 | /*.op_offload =*/ true, |
2926 | 0 | /*.swa_full =*/ true, |
2927 | 0 | /*.kv_unified =*/ false, |
2928 | 0 | /*.samplers =*/ nullptr, |
2929 | 0 | /*.n_samplers =*/ 0, |
2930 | 0 | }; |
2931 | |
|
2932 | 0 | return result; |
2933 | 0 | } |
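| | // Usage sketch (editor note, not part of the instrumented source): the defaults |
| | // above are usually overridden field by field before creating a context; the |
| | // values below are illustrative only: |
| | // |
| | //     llama_context_params cparams = llama_context_default_params(); |
| | //     cparams.n_ctx     = 4096; |
| | //     cparams.n_threads = 8; |
| | //     llama_context * lctx = llama_init_from_model(model, cparams); |
| | //     if (lctx == nullptr) { /* initialization failed, see the log */ } |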
2934 | | |
2935 | | llama_context * llama_init_from_model( |
2936 | | llama_model * model, |
2937 | 0 | llama_context_params params) { |
2938 | 0 | if (!model) { |
2939 | 0 | LLAMA_LOG_ERROR("%s: model cannot be NULL\n", __func__); |
2940 | 0 | return nullptr; |
2941 | 0 | } |
2942 | | |
2943 | 0 | if (params.n_batch == 0 && params.n_ubatch == 0) { |
2944 | 0 | LLAMA_LOG_ERROR("%s: n_batch and n_ubatch cannot both be zero\n", __func__); |
2945 | 0 | return nullptr; |
2946 | 0 | } |
2947 | | |
2948 | 0 | if (params.n_ctx == 0 && model->hparams.n_ctx_train == 0) { |
2949 | 0 | LLAMA_LOG_ERROR("%s: n_ctx and model->hparams.n_ctx_train cannot both be zero\n", __func__); |
2950 | 0 | return nullptr; |
2951 | 0 | } |
2952 | | |
2953 | 0 | if (params.flash_attn_type != LLAMA_FLASH_ATTN_TYPE_DISABLED && model->arch == LLM_ARCH_GROK) { |
2954 | 0 | LLAMA_LOG_WARN("%s: flash_attn is not compatible with Grok - forcing off\n", __func__); |
2955 | 0 | params.flash_attn_type = LLAMA_FLASH_ATTN_TYPE_DISABLED; |
2956 | 0 | } |
2957 | |
|
2958 | 0 | if (params.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_AUTO && ggml_is_quantized(params.type_k)) { |
2959 | 0 | const uint32_t blck_size = ggml_blck_size(params.type_k); |
2960 | 0 | if (model->hparams.n_embd_head_k % blck_size != 0) { |
2961 | 0 | LLAMA_LOG_ERROR("%s: K cache type %s with block size %u does not divide n_embd_head_k=%u\n", |
2962 | 0 | __func__, ggml_type_name(params.type_k), blck_size, model->hparams.n_embd_head_k); |
2963 | 0 | return nullptr; |
2964 | 0 | } |
2965 | 0 | } |
2966 | | |
2967 | 0 | if (params.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_AUTO && ggml_is_quantized(params.type_v)) { |
2968 | 0 | const uint32_t blck_size = ggml_blck_size(params.type_v); |
2969 | 0 | if (model->hparams.n_embd_head_v % blck_size != 0) { |
2970 | 0 | LLAMA_LOG_ERROR("%s: V cache type %s with block size %u does not divide n_embd_head_v=%u\n", |
2971 | 0 | __func__, ggml_type_name(params.type_v), blck_size, model->hparams.n_embd_head_v); |
2972 | 0 | return nullptr; |
2973 | 0 | } |
2974 | 0 | } |
2975 | | |
2976 | 0 | if (ggml_is_quantized(params.type_v) && params.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_DISABLED) { |
2977 | 0 | LLAMA_LOG_ERROR("%s: V cache quantization requires flash_attn\n", __func__); |
2978 | 0 | return nullptr; |
2979 | 0 | } |
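| | // Editor note: in short, when flash attention is AUTO a quantized K or V cache type |
| | // is rejected unless the corresponding head size is a multiple of the type's block |
| | // size, and a quantized V cache is rejected outright when flash attention is disabled. |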
2980 | | |
2981 | 0 | if (params.pooling_type != LLAMA_POOLING_TYPE_UNSPECIFIED && |
2982 | 0 | params.pooling_type != model->hparams.pooling_type) { |
2983 | | // user-specified pooling type is different from the model default |
2984 | 0 | LLAMA_LOG_WARN("%s: model default pooling_type is [%d], but [%d] was specified\n", __func__, |
2985 | 0 | model->hparams.pooling_type, params.pooling_type); |
2986 | 0 | } |
2987 | |
|
2988 | 0 | try { |
2989 | 0 | auto * ctx = new llama_context(*model, params); |
2990 | 0 | return ctx; |
2991 | 0 | } catch (const std::exception & err) { |
2992 | 0 | LLAMA_LOG_ERROR("%s: failed to initialize the context: %s\n", __func__, err.what()); |
2993 | 0 | } |
2994 | | |
2995 | 0 | return nullptr; |
2996 | 0 | } |
2997 | | |
2998 | | // deprecated |
2999 | | llama_context * llama_new_context_with_model( |
3000 | | llama_model * model, |
3001 | 0 | llama_context_params params) { |
3002 | 0 | return llama_init_from_model(model, params); |
3003 | 0 | } |
3004 | | |
3005 | 0 | void llama_free(llama_context * ctx) { |
3006 | 0 | delete ctx; |
3007 | 0 | } |
3008 | | |
3009 | 0 | uint32_t llama_n_ctx(const llama_context * ctx) { |
3010 | 0 | return ctx->n_ctx(); |
3011 | 0 | } |
3012 | | |
3013 | 0 | uint32_t llama_n_ctx_seq(const llama_context * ctx) { |
3014 | 0 | return ctx->n_ctx_seq(); |
3015 | 0 | } |
3016 | | |
3017 | 0 | uint32_t llama_n_batch(const llama_context * ctx) { |
3018 | 0 | return ctx->n_batch(); |
3019 | 0 | } |
3020 | | |
3021 | 0 | uint32_t llama_n_ubatch(const llama_context * ctx) { |
3022 | 0 | return ctx->n_ubatch(); |
3023 | 0 | } |
3024 | | |
3025 | 0 | uint32_t llama_n_seq_max(const llama_context * ctx) { |
3026 | 0 | return ctx->n_seq_max(); |
3027 | 0 | } |
3028 | | |
3029 | 0 | const llama_model * llama_get_model(const llama_context * ctx) { |
3030 | 0 | return &ctx->get_model(); |
3031 | 0 | } |
3032 | | |
3033 | 0 | enum llama_pooling_type llama_pooling_type(const llama_context * ctx) { |
3034 | 0 | return ctx->pooling_type(); |
3035 | 0 | } |
3036 | | |
3037 | | void llama_attach_threadpool( |
3038 | | llama_context * ctx, |
3039 | | ggml_threadpool_t threadpool, |
3040 | 0 | ggml_threadpool_t threadpool_batch) { |
3041 | 0 | ctx->attach_threadpool(threadpool, threadpool_batch); |
3042 | 0 | } |
3043 | | |
3044 | 0 | void llama_detach_threadpool(llama_context * ctx) { |
3045 | 0 | ctx->detach_threadpool(); |
3046 | 0 | } |
3047 | | |
3048 | 0 | void llama_set_n_threads(llama_context * ctx, int32_t n_threads, int32_t n_threads_batch) { |
3049 | 0 | ctx->set_n_threads(n_threads, n_threads_batch); |
3050 | 0 | } |
3051 | | |
3052 | 0 | int32_t llama_n_threads(llama_context * ctx) { |
3053 | 0 | return ctx->n_threads(); |
3054 | 0 | } |
3055 | | |
3056 | 0 | int32_t llama_n_threads_batch(llama_context * ctx) { |
3057 | 0 | return ctx->n_threads_batch(); |
3058 | 0 | } |
3059 | | |
3060 | 0 | void llama_set_abort_callback(llama_context * ctx, bool (*abort_callback)(void * data), void * abort_callback_data) { |
3061 | 0 | ctx->set_abort_callback(abort_callback, abort_callback_data); |
3062 | 0 | } |
3063 | | |
3064 | 0 | void llama_set_embeddings(llama_context * ctx, bool embeddings) { |
3065 | 0 | ctx->set_embeddings(embeddings); |
3066 | 0 | } |
3067 | | |
3068 | 0 | void llama_set_causal_attn(llama_context * ctx, bool causal_attn) { |
3069 | 0 | ctx->set_causal_attn(causal_attn); |
3070 | 0 | } |
3071 | | |
3072 | 0 | void llama_set_warmup(llama_context * ctx, bool warmup) { |
3073 | 0 | ctx->set_warmup(warmup); |
3074 | 0 | } |
3075 | | |
3076 | 0 | void llama_synchronize(llama_context * ctx) { |
3077 | 0 | ctx->synchronize(); |
3078 | 0 | } |
3079 | | |
3080 | 0 | float * llama_get_logits(llama_context * ctx) { |
3081 | 0 | ctx->synchronize(); |
3082 | |
|
3083 | 0 | return ctx->get_logits(); |
3084 | 0 | } |
3085 | | |
3086 | 0 | float * llama_get_logits_ith(llama_context * ctx, int32_t i) { |
3087 | 0 | ctx->synchronize(); |
3088 | |
|
3089 | 0 | float * res = nullptr; |
3090 | |
|
3091 | 0 | res = ctx->get_sampled_logits_ith(i); |
3092 | |
|
3093 | 0 | if (!res) { |
3094 | 0 | res = ctx->get_logits_ith(i); |
3095 | 0 | } |
3096 | |
|
3097 | 0 | return res; |
3098 | 0 | } |
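| | // Editor note: llama_get_logits_ith() prefers the logits produced by a backend |
| | // sampler for output i and only falls back to the regular per-token logits buffer |
| | // when no sampled logits are available. Typical use after llama_decode(), assuming |
| | // logits were requested for the last token of `batch`: |
| | // |
| | //     const float * logits = llama_get_logits_ith(ctx, batch.n_tokens - 1); |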
3099 | | |
3100 | 0 | float * llama_get_embeddings(llama_context * ctx) { |
3101 | 0 | ctx->synchronize(); |
3102 | |
|
3103 | 0 | return ctx->get_embeddings(); |
3104 | 0 | } |
3105 | | |
3106 | 0 | float * llama_get_embeddings_ith(llama_context * ctx, int32_t i) { |
3107 | 0 | ctx->synchronize(); |
3108 | |
|
3109 | 0 | return ctx->get_embeddings_ith(i); |
3110 | 0 | } |
3111 | | |
3112 | 0 | float * llama_get_embeddings_seq(llama_context * ctx, llama_seq_id seq_id) { |
3113 | 0 | ctx->synchronize(); |
3114 | |
|
3115 | 0 | return ctx->get_embeddings_seq(seq_id); |
3116 | 0 | } |
3117 | | |
3118 | 0 | bool llama_set_sampler(llama_context * ctx, llama_seq_id seq_id, llama_sampler * smpl) { |
3119 | 0 | return ctx->set_sampler(seq_id, smpl); |
3120 | 0 | } |
3121 | | |
3122 | 0 | llama_token llama_get_sampled_token_ith(llama_context * ctx, int32_t i) { |
3123 | 0 | ctx->synchronize(); |
3124 | |
|
3125 | 0 | return ctx->get_sampled_token_ith(i); |
3126 | 0 | } |
3127 | | |
3128 | 0 | float * llama_get_sampled_probs_ith(llama_context * ctx, int32_t i) { |
3129 | 0 | ctx->synchronize(); |
3130 | |
|
3131 | 0 | return ctx->get_sampled_probs_ith(i); |
3132 | 0 | } |
3133 | | |
3134 | 0 | float * llama_get_sampled_logits_ith(llama_context * ctx, int32_t i) { |
3135 | 0 | ctx->synchronize(); |
3136 | |
|
3137 | 0 | return ctx->get_sampled_logits_ith(i); |
3138 | 0 | } |
3139 | | |
3140 | 0 | llama_token * llama_get_sampled_candidates_ith(llama_context * ctx, int32_t i) { |
3141 | 0 | ctx->synchronize(); |
3142 | |
|
3143 | 0 | return const_cast<llama_token *>(ctx->get_sampled_candidates_ith(i)); |
3144 | 0 | } |
3145 | | |
3146 | 0 | uint32_t llama_get_sampled_candidates_count_ith(llama_context * ctx, int32_t i) { |
3147 | 0 | ctx->synchronize(); |
3148 | |
|
3149 | 0 | return static_cast<uint32_t>(ctx->get_sampled_candidates_count(i)); |
3150 | 0 | } |
3151 | | |
3152 | 0 | uint32_t llama_get_sampled_logits_count_ith(llama_context * ctx, int32_t i) { |
3153 | 0 | ctx->synchronize(); |
3154 | |
|
3155 | 0 | return static_cast<uint32_t>(ctx->get_sampled_logits_count(i)); |
3156 | 0 | } |
3157 | | |
3158 | 0 | uint32_t llama_get_sampled_probs_count_ith(llama_context * ctx, int32_t i) { |
3159 | 0 | ctx->synchronize(); |
3160 | |
|
3161 | 0 | return static_cast<uint32_t>(ctx->get_sampled_probs_count(i)); |
3162 | 0 | } |
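| | // Usage sketch (editor note): a backend sampler is attached per sequence and its |
| | // results are read back per output index after decoding; `chain` is assumed to be |
| | // a llama_sampler built elsewhere, and the index convention is assumed to match |
| | // llama_get_logits_ith(): |
| | // |
| | //     llama_set_sampler(ctx, /*seq_id =*/ 0, chain); |
| | //     llama_decode(ctx, batch); |
| | //     const llama_token tok = llama_get_sampled_token_ith(ctx, batch.n_tokens - 1); |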
3163 | | |
3164 | | // llama adapter API |
3165 | | |
3166 | | int32_t llama_set_adapter_lora( |
3167 | | llama_context * ctx, |
3168 | | llama_adapter_lora * adapter, |
3169 | 0 | float scale) { |
3170 | 0 | ctx->set_adapter_lora(adapter, scale); |
3171 | |
|
3172 | 0 | return 0; |
3173 | 0 | } |
3174 | | |
3175 | | int32_t llama_rm_adapter_lora( |
3176 | | llama_context * ctx, |
3177 | 0 | llama_adapter_lora * adapter) { |
3178 | 0 | bool res = ctx->rm_adapter_lora(adapter); |
3179 | |
|
3180 | 0 | return res ? 0 : -1; |
3181 | 0 | } |
3182 | | |
3183 | 0 | void llama_clear_adapter_lora(llama_context * ctx) { |
3184 | 0 | ctx->clear_adapter_lora(); |
3185 | 0 | } |
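| | // Usage sketch (editor note): LoRA adapters are enabled per context with a scale; |
| | // `adapter` is assumed to have been loaded elsewhere as a llama_adapter_lora *: |
| | // |
| | //     llama_set_adapter_lora  (ctx, adapter, /*scale =*/ 1.0f); // enable / rescale |
| | //     llama_rm_adapter_lora   (ctx, adapter);                   // remove this adapter |
| | //     llama_clear_adapter_lora(ctx);                            // remove all adapters |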
3186 | | |
3187 | | int32_t llama_apply_adapter_cvec( |
3188 | | llama_context * ctx, |
3189 | | const float * data, |
3190 | | size_t len, |
3191 | | int32_t n_embd, |
3192 | | int32_t il_start, |
3193 | 0 | int32_t il_end) { |
3194 | 0 | bool res = ctx->apply_adapter_cvec(data, len, n_embd, il_start, il_end); |
3195 | |
|
3196 | 0 | return res ? 0 : -1; |
3197 | 0 | } |
3198 | | |
3199 | | // |
3200 | | // memory |
3201 | | // |
3202 | | |
3203 | 0 | llama_memory_t llama_get_memory(const struct llama_context * ctx) { |
3204 | 0 | return ctx->get_memory(); |
3205 | 0 | } |
3206 | | |
3207 | 0 | void llama_memory_clear(llama_memory_t mem, bool data) { |
3208 | 0 | if (!mem) { |
3209 | 0 | return; |
3210 | 0 | } |
3211 | | |
3212 | 0 | mem->clear(data); |
3213 | 0 | } |
3214 | | |
3215 | | bool llama_memory_seq_rm( |
3216 | | llama_memory_t mem, |
3217 | | llama_seq_id seq_id, |
3218 | | llama_pos p0, |
3219 | 0 | llama_pos p1) { |
3220 | 0 | if (!mem) { |
3221 | 0 | return true; |
3222 | 0 | } |
3223 | | |
3224 | 0 | return mem->seq_rm(seq_id, p0, p1); |
3225 | 0 | } |
3226 | | |
3227 | | void llama_memory_seq_cp( |
3228 | | llama_memory_t mem, |
3229 | | llama_seq_id seq_id_src, |
3230 | | llama_seq_id seq_id_dst, |
3231 | | llama_pos p0, |
3232 | 0 | llama_pos p1) { |
3233 | 0 | if (!mem) { |
3234 | 0 | return; |
3235 | 0 | } |
3236 | | |
3237 | 0 | mem->seq_cp(seq_id_src, seq_id_dst, p0, p1); |
3238 | 0 | } |
3239 | | |
3240 | | void llama_memory_seq_keep( |
3241 | | llama_memory_t mem, |
3242 | 0 | llama_seq_id seq_id) { |
3243 | 0 | if (!mem) { |
3244 | 0 | return; |
3245 | 0 | } |
3246 | | |
3247 | 0 | mem->seq_keep(seq_id); |
3248 | 0 | } |
3249 | | |
3250 | | void llama_memory_seq_add( |
3251 | | llama_memory_t mem, |
3252 | | llama_seq_id seq_id, |
3253 | | llama_pos p0, |
3254 | | llama_pos p1, |
3255 | 0 | llama_pos delta) { |
3256 | 0 | if (!mem) { |
3257 | 0 | return; |
3258 | 0 | } |
3259 | | |
3260 | 0 | mem->seq_add(seq_id, p0, p1, delta); |
3261 | 0 | } |
3262 | | |
3263 | | void llama_memory_seq_div( |
3264 | | llama_memory_t mem, |
3265 | | llama_seq_id seq_id, |
3266 | | llama_pos p0, |
3267 | | llama_pos p1, |
3268 | 0 | int d) { |
3269 | 0 | if (!mem) { |
3270 | 0 | return; |
3271 | 0 | } |
3272 | | |
3273 | 0 | mem->seq_div(seq_id, p0, p1, d); |
3274 | 0 | } |
3275 | | |
3276 | | llama_pos llama_memory_seq_pos_min( |
3277 | | llama_memory_t mem, |
3278 | 0 | llama_seq_id seq_id) { |
3279 | 0 | if (!mem) { |
3280 | 0 | return -1; |
3281 | 0 | } |
3282 | | |
3283 | 0 | return mem->seq_pos_min(seq_id); |
3284 | 0 | } |
3285 | | |
3286 | | llama_pos llama_memory_seq_pos_max( |
3287 | | llama_memory_t mem, |
3288 | 0 | llama_seq_id seq_id) { |
3289 | 0 | if (!mem) { |
3290 | 0 | return -1; |
3291 | 0 | } |
3292 | | |
3293 | 0 | return mem->seq_pos_max(seq_id); |
3294 | 0 | } |
3295 | | |
3296 | 0 | bool llama_memory_can_shift(llama_memory_t mem) { |
3297 | 0 | if (!mem) { |
3298 | 0 | return false; |
3299 | 0 | } |
3300 | | |
3301 | 0 | return mem->get_can_shift(); |
3302 | 0 | } |
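| | // Usage sketch (editor note): a common "context shift" for sequence 0 drops the |
| | // oldest 32 positions and shifts the rest back so the sequence stays contiguous; |
| | // p1 = -1 means "to the end of the sequence", and this is only valid when |
| | // llama_memory_can_shift() returns true: |
| | // |
| | //     llama_memory_t mem = llama_get_memory(ctx); |
| | //     llama_memory_seq_rm (mem, /*seq_id =*/ 0, /*p0 =*/ 0,  /*p1 =*/ 32); |
| | //     llama_memory_seq_add(mem, /*seq_id =*/ 0, /*p0 =*/ 32, /*p1 =*/ -1, /*delta =*/ -32); |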
3303 | | |
3304 | | // llama state API |
3305 | | |
3306 | | // deprecated |
3307 | 0 | size_t llama_get_state_size(llama_context * ctx) { |
3308 | 0 | return llama_state_get_size(ctx); |
3309 | 0 | } |
3310 | | |
3311 | | // deprecated |
3312 | 0 | size_t llama_copy_state_data(llama_context * ctx, uint8_t * dst) { |
3313 | 0 | return llama_state_get_data(ctx, dst, -1); |
3314 | 0 | } |
3315 | | |
3316 | | // deprecated |
3317 | 0 | size_t llama_set_state_data(llama_context * ctx, const uint8_t * src) { |
3318 | 0 | return llama_state_set_data(ctx, src, -1); |
3319 | 0 | } |
3320 | | |
3321 | | // deprecated |
3322 | 0 | bool llama_load_session_file(llama_context * ctx, const char * path_session, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) { |
3323 | 0 | return llama_state_load_file(ctx, path_session, tokens_out, n_token_capacity, n_token_count_out); |
3324 | 0 | } |
3325 | | |
3326 | | // deprecated |
3327 | 0 | bool llama_save_session_file(llama_context * ctx, const char * path_session, const llama_token * tokens, size_t n_token_count) { |
3328 | 0 | return llama_state_save_file(ctx, path_session, tokens, n_token_count); |
3329 | 0 | } |
3330 | | |
3331 | | // Returns the *actual* size of the state. |
3332 | | // Intended to be used when saving the state to a buffer. |
3333 | 0 | size_t llama_state_get_size(llama_context * ctx) { |
3334 | 0 | return ctx->state_get_size(); |
3335 | 0 | } |
3336 | | |
3337 | 0 | size_t llama_state_get_data(llama_context * ctx, uint8_t * dst, size_t size) { |
3338 | 0 | ctx->synchronize(); |
3339 | |
|
3340 | 0 | return ctx->state_get_data(dst, size); |
3341 | 0 | } |
3342 | | |
3343 | | // Sets the state, reading from the specified source address |
3344 | 0 | size_t llama_state_set_data(llama_context * ctx, const uint8_t * src, size_t size) { |
3345 | 0 | ctx->synchronize(); |
3346 | |
|
3347 | 0 | return ctx->state_set_data(src, size); |
3348 | 0 | } |
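| | // Usage sketch (editor note): round-tripping the full context state through a |
| | // caller-owned buffer: |
| | // |
| | //     std::vector<uint8_t> buf(llama_state_get_size(ctx)); |
| | //     const size_t n_written = llama_state_get_data(ctx, buf.data(), buf.size()); |
| | //     // ... later, restore into a compatible context: |
| | //     const size_t n_read = llama_state_set_data(ctx, buf.data(), n_written); |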
3349 | | |
3350 | 0 | bool llama_state_load_file(llama_context * ctx, const char * path_session, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) { |
3351 | 0 | ctx->synchronize(); |
3352 | |
|
3353 | 0 | try { |
3354 | 0 | return ctx->state_load_file(path_session, tokens_out, n_token_capacity, n_token_count_out); |
3355 | 0 | } catch (const std::exception & err) { |
3356 | 0 | LLAMA_LOG_ERROR("%s: error loading session file: %s\n", __func__, err.what()); |
3357 | 0 | return false; |
3358 | 0 | } |
3359 | 0 | } |
3360 | | |
3361 | 0 | bool llama_state_save_file(llama_context * ctx, const char * path_session, const llama_token * tokens, size_t n_token_count) { |
3362 | 0 | ctx->synchronize(); |
3363 | |
|
3364 | 0 | try { |
3365 | 0 | return ctx->state_save_file(path_session, tokens, n_token_count); |
3366 | 0 | } catch (const std::exception & err) { |
3367 | 0 | LLAMA_LOG_ERROR("%s: error saving session file: %s\n", __func__, err.what()); |
3368 | 0 | return false; |
3369 | 0 | } |
3370 | 0 | } |
3371 | | |
3372 | 0 | size_t llama_state_seq_get_size(llama_context * ctx, llama_seq_id seq_id) { |
3373 | 0 | return llama_state_seq_get_size_ext(ctx, seq_id, 0); |
3374 | 0 | } |
3375 | | |
3376 | 0 | size_t llama_state_seq_get_data(llama_context * ctx, uint8_t * dst, size_t size, llama_seq_id seq_id) { |
3377 | 0 | return llama_state_seq_get_data_ext(ctx, dst, size, seq_id, 0); |
3378 | 0 | } |
3379 | | |
3380 | 0 | size_t llama_state_seq_set_data(llama_context * ctx, const uint8_t * src, size_t size, llama_seq_id seq_id) { |
3381 | 0 | return llama_state_seq_set_data_ext(ctx, src, size, seq_id, 0); |
3382 | 0 | } |
3383 | | |
3384 | 0 | size_t llama_state_seq_get_size_ext(llama_context * ctx, llama_seq_id seq_id, llama_state_seq_flags flags) { |
3385 | 0 | return ctx->state_seq_get_size(seq_id, flags); |
3386 | 0 | } |
3387 | | |
3388 | 0 | size_t llama_state_seq_get_data_ext(llama_context * ctx, uint8_t * dst, size_t size, llama_seq_id seq_id, llama_state_seq_flags flags) { |
3389 | 0 | ctx->synchronize(); |
3390 | |
|
3391 | 0 | return ctx->state_seq_get_data(seq_id, dst, size, flags); |
3392 | 0 | } |
3393 | | |
3394 | 0 | size_t llama_state_seq_set_data_ext(llama_context * ctx, const uint8_t * src, size_t size, llama_seq_id seq_id, llama_state_seq_flags flags) { |
3395 | 0 | ctx->synchronize(); |
3396 | |
|
3397 | 0 | return ctx->state_seq_set_data(seq_id, src, size, flags); |
3398 | 0 | } |
3399 | | |
3400 | 0 | size_t llama_state_seq_save_file(llama_context * ctx, const char * filepath, llama_seq_id seq_id, const llama_token * tokens, size_t n_token_count) { |
3401 | 0 | ctx->synchronize(); |
3402 | |
|
3403 | 0 | try { |
3404 | 0 | return ctx->state_seq_save_file(seq_id, filepath, tokens, n_token_count); |
3405 | 0 | } catch (const std::exception & err) { |
3406 | 0 | LLAMA_LOG_ERROR("%s: error saving sequence state file: %s\n", __func__, err.what()); |
3407 | 0 | return 0; |
3408 | 0 | } |
3409 | 0 | } |
3410 | | |
3411 | 0 | size_t llama_state_seq_load_file(llama_context * ctx, const char * filepath, llama_seq_id dest_seq_id, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) { |
3412 | 0 | ctx->synchronize(); |
3413 | |
|
3414 | 0 | try { |
3415 | 0 | return ctx->state_seq_load_file(dest_seq_id, filepath, tokens_out, n_token_capacity, n_token_count_out); |
3416 | 0 | } catch (const std::exception & err) { |
3417 | 0 | LLAMA_LOG_ERROR("%s: error loading sequence state file: %s\n", __func__, err.what()); |
3418 | 0 | return 0; |
3419 | 0 | } |
3420 | 0 | } |
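| | // Usage sketch (editor note): saving and restoring the state of a single sequence |
| | // together with the tokens that produced it; the file name is a placeholder: |
| | // |
| | //     llama_state_seq_save_file(ctx, "seq0.bin", /*seq_id =*/ 0, tokens.data(), tokens.size()); |
| | // |
| | //     std::vector<llama_token> tokens_out(llama_n_ctx(ctx)); |
| | //     size_t n_restored = 0; |
| | //     llama_state_seq_load_file(ctx, "seq0.bin", /*dest_seq_id =*/ 0, |
| | //                               tokens_out.data(), tokens_out.size(), &n_restored); |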
3421 | | |
3422 | | /// |
3423 | | |
3424 | | int32_t llama_encode( |
3425 | | llama_context * ctx, |
3426 | 0 | llama_batch batch) { |
3427 | 0 | const int ret = ctx->encode(batch); |
3428 | 0 | if (ret != 0) { |
3429 | 0 | LLAMA_LOG_ERROR("%s: failed to encode, ret = %d\n", __func__, ret); |
3430 | 0 | } |
3431 | |
|
3432 | 0 | return ret; |
3433 | 0 | } |
3434 | | |
3435 | | int32_t llama_decode( |
3436 | | llama_context * ctx, |
3437 | 0 | llama_batch batch) { |
3438 | 0 | const int ret = ctx->decode(batch); |
3439 | 0 | if (ret != 0 && ret != 1) { |
3440 | 0 | LLAMA_LOG_ERROR("%s: failed to decode, ret = %d\n", __func__, ret); |
3441 | 0 | } |
3442 | |
|
3443 | 0 | return ret; |
3444 | 0 | } |
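| | // Usage sketch (editor note): a minimal single-sequence decode of n prompt tokens, |
| | // requesting logits only for the last one; mirrors the batch layout used in |
| | // opt_epoch_iter() above: |
| | // |
| | //     llama_batch batch = llama_batch_init(n, 0, 1); |
| | //     batch.n_tokens = n; |
| | //     for (int32_t k = 0; k < n; ++k) { |
| | //         batch.token   [k]    = prompt[k]; |
| | //         batch.pos     [k]    = k; |
| | //         batch.n_seq_id[k]    = 1; |
| | //         batch.seq_id  [k][0] = 0; |
| | //         batch.logits  [k]    = (k == n - 1); |
| | //     } |
| | //     if (llama_decode(ctx, batch) != 0) { /* see the log for details */ } |
| | //     llama_batch_free(batch); |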
3445 | | |
3446 | | // |
3447 | | // perf |
3448 | | // |
3449 | | |
3450 | 0 | llama_perf_context_data llama_perf_context(const llama_context * ctx) { |
3451 | 0 | llama_perf_context_data data = {}; |
3452 | |
|
3453 | 0 | if (ctx == nullptr) { |
3454 | 0 | return data; |
3455 | 0 | } |
3456 | | |
3457 | 0 | data = ctx->perf_get_data(); |
3458 | |
|
3459 | 0 | return data; |
3460 | 0 | } |
3461 | | |
3462 | 0 | void llama_perf_context_print(const llama_context * ctx) { |
3463 | 0 | const auto data = llama_perf_context(ctx); |
3464 | |
|
3465 | 0 | const double t_end_ms = 1e-3 * ggml_time_us(); |
3466 | |
|
3467 | 0 | LLAMA_LOG_INFO("%s: load time = %10.2f ms\n", __func__, data.t_load_ms); |
3468 | 0 | LLAMA_LOG_INFO("%s: prompt eval time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)\n", |
3469 | 0 | __func__, data.t_p_eval_ms, data.n_p_eval, data.t_p_eval_ms / data.n_p_eval, 1e3 / data.t_p_eval_ms * data.n_p_eval); |
3470 | 0 | LLAMA_LOG_INFO("%s: eval time = %10.2f ms / %5d runs (%8.2f ms per token, %8.2f tokens per second)\n", |
3471 | 0 | __func__, data.t_eval_ms, data.n_eval, data.t_eval_ms / data.n_eval, 1e3 / data.t_eval_ms * data.n_eval); |
3472 | 0 | LLAMA_LOG_INFO("%s: total time = %10.2f ms / %5d tokens\n", __func__, (t_end_ms - data.t_start_ms), (data.n_p_eval + data.n_eval)); |
3473 | 0 | LLAMA_LOG_INFO("%s: graphs reused = %10d\n", __func__, data.n_reused); |
3474 | 0 | } |
3475 | | |
3476 | 0 | void llama_perf_context_reset(llama_context * ctx) { |
3477 | 0 | ctx->perf_reset(); |
3478 | 0 | } |
3479 | | |
3480 | 0 | void llama_memory_breakdown_print(const struct llama_context * ctx) { |
3481 | 0 | const std::vector<ggml_backend_dev_t> & devices = ctx->get_model().devices; |
3482 | |
|
3483 | 0 | std::map<ggml_backend_buffer_type_t, llama_memory_breakdown_data> memory_breakdown = ctx->memory_breakdown(); |
3484 | |
|
3485 | 0 | std::vector<std::array<std::string, 9>> table_data; |
3486 | 0 | table_data.reserve(devices.size()); |
3487 | 0 | const std::string template_header = "%s: | %s | %s %s %s %s %s %s %s |\n"; |
3488 | 0 | const std::string template_gpu = "%s: | %s | %s = %s + (%s = %s + %s + %s) + %s |\n"; |
3489 | 0 | const std::string template_other = "%s: | %s | %s %s %s = %s + %s + %s %s |\n"; |
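| | // Editor note: GPU rows read as total = free + (self = model + context + compute) + unaccounted; |
| | // host rows and other buffer types only fill in the self/model/context/compute columns. |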
3490 | |
|
3491 | 0 | table_data.push_back({template_header, "memory breakdown [MiB]", "total", "free", "self", "model", "context", "compute", "unaccounted"}); |
3492 | |
|
3493 | 0 | constexpr size_t MiB = 1024 * 1024; |
3494 | 0 | const std::vector<std::string> desc_prefixes_strip = {"NVIDIA ", "GeForce ", "Tesla ", "AMD ", "Radeon ", "Instinct "}; |
3495 | | |
3496 | | // track seen buffer types to avoid double counting: |
3497 | 0 | std::set<ggml_backend_buffer_type_t> seen_buffer_types; |
3498 | | |
3499 | | // accumulative memory breakdown for each device and for host: |
3500 | 0 | std::vector<llama_memory_breakdown_data> mb_dev(devices.size()); |
3501 | 0 | llama_memory_breakdown_data mb_host; |
3502 | |
|
3503 | 0 | for (const auto & buft_mb : memory_breakdown) { |
3504 | 0 | ggml_backend_buffer_type_t buft = buft_mb.first; |
3505 | 0 | const llama_memory_breakdown_data & mb = buft_mb.second; |
3506 | 0 | if (ggml_backend_buft_is_host(buft)) { |
3507 | 0 | mb_host.model += mb.model; |
3508 | 0 | mb_host.context += mb.context; |
3509 | 0 | mb_host.compute += mb.compute; |
3510 | 0 | seen_buffer_types.insert(buft); |
3511 | 0 | continue; |
3512 | 0 | } |
3513 | 0 | ggml_backend_dev_t dev = ggml_backend_buft_get_device(buft); |
3514 | 0 | if (dev) { |
3515 | 0 | int i_dev = -1; |
3516 | 0 | for (size_t i = 0; i < devices.size(); i++) { |
3517 | 0 | if (devices[i] == dev) { |
3518 | 0 | i_dev = i; |
3519 | 0 | break; |
3520 | 0 | } |
3521 | 0 | } |
3522 | 0 | if (i_dev != -1) { |
3523 | 0 | mb_dev[i_dev].model += mb.model; |
3524 | 0 | mb_dev[i_dev].context += mb.context; |
3525 | 0 | mb_dev[i_dev].compute += mb.compute; |
3526 | 0 | seen_buffer_types.insert(buft); |
3527 | 0 | continue; |
3528 | 0 | } |
3529 | 0 | } |
3530 | 0 | } |
3531 | | |
3532 | | // print memory breakdown for each device: |
3533 | 0 | for (size_t i = 0; i < devices.size(); i++) { |
3534 | 0 | ggml_backend_dev_t dev = devices[i]; |
3535 | 0 | llama_memory_breakdown_data mb = mb_dev[i]; |
3536 | |
|
3537 | 0 | const std::string name = ggml_backend_dev_name(dev); |
3538 | 0 | std::string desc = ggml_backend_dev_description(dev); |
3539 | 0 | for (const std::string & prefix : desc_prefixes_strip) { |
3540 | 0 | if (desc.length() >= prefix.length() && desc.substr(0, prefix.length()) == prefix) { |
3541 | 0 | desc = desc.substr(prefix.length()); |
3542 | 0 | } |
3543 | 0 | } |
3544 | |
|
3545 | 0 | size_t free, total; |
3546 | 0 | ggml_backend_dev_memory(dev, &free, &total); |
3547 | |
|
3548 | 0 | const size_t self = mb.model + mb.context + mb.compute; |
3549 | 0 | const size_t unaccounted = total - self - free; |
3550 | |
|
3551 | 0 | table_data.push_back({ |
3552 | 0 | template_gpu, |
3553 | 0 | " - " + name + " (" + desc + ")", |
3554 | 0 | std::to_string(total / MiB), |
3555 | 0 | std::to_string(free / MiB), |
3556 | 0 | std::to_string(self / MiB), |
3557 | 0 | std::to_string(mb.model / MiB), |
3558 | 0 | std::to_string(mb.context / MiB), |
3559 | 0 | std::to_string(mb.compute / MiB), |
3560 | 0 | std::to_string(unaccounted / MiB)}); |
3561 | 0 | } |
3562 | | |
3563 | | // print memory breakdown for host: |
3564 | 0 | { |
3565 | 0 | const size_t self = mb_host.model + mb_host.context + mb_host.compute; |
3566 | 0 | table_data.push_back({ |
3567 | 0 | template_other, |
3568 | 0 | " - Host", |
3569 | 0 | "", // total |
3570 | 0 | "", // free |
3571 | 0 | std::to_string(self / MiB), |
3572 | 0 | std::to_string(mb_host.model / MiB), |
3573 | 0 | std::to_string(mb_host.context / MiB), |
3574 | 0 | std::to_string(mb_host.compute / MiB), |
3575 | 0 | ""}); // unaccounted |
3576 | 0 | } |
3577 | | |
3578 | | // print memory breakdown for all remaining buffer types: |
3579 | 0 | for (const auto & buft_mb : memory_breakdown) { |
3580 | 0 | ggml_backend_buffer_type_t buft = buft_mb.first; |
3581 | 0 | const llama_memory_breakdown_data & mb = buft_mb.second; |
3582 | 0 | if (seen_buffer_types.count(buft) == 1) { |
3583 | 0 | continue; |
3584 | 0 | } |
3585 | 0 | const std::string name = ggml_backend_buft_name(buft); |
3586 | 0 | const size_t self = mb.model + mb.context + mb.compute; |
3587 | 0 | table_data.push_back({ |
3588 | 0 | template_other, |
3589 | 0 | " - " + name, |
3590 | 0 | "", // total |
3591 | 0 | "", // free |
3592 | 0 | std::to_string(self / MiB), |
3593 | 0 | std::to_string(mb.model / MiB), |
3594 | 0 | std::to_string(mb.context / MiB), |
3595 | 0 | std::to_string(mb.compute / MiB), |
3596 | 0 | ""}); // unaccounted |
3597 | 0 | seen_buffer_types.insert(buft); |
3598 | 0 | } |
3599 | |
|
3600 | 0 | for (size_t j = 1; j < table_data[0].size(); j++) { |
3601 | 0 | size_t max_len = 0; |
3602 | 0 | for (const auto & td : table_data) { |
3603 | 0 | max_len = std::max(max_len, td[j].length()); |
3604 | 0 | } |
3605 | 0 | for (auto & td : table_data) { |
3606 | 0 | td[j].insert(j == 1 ? td[j].length() : 0, max_len - td[j].length(), ' '); |
3607 | 0 | } |
3608 | 0 | } |
3609 | 0 | for (const auto & td : table_data) { |
3610 | 0 | LLAMA_LOG_INFO(td[0].c_str(), |
3611 | 0 | __func__, td[1].c_str(), td[2].c_str(), td[3].c_str(), td[4].c_str(), td[5].c_str(), |
3612 | 0 | td[6].c_str(), td[7].c_str(), td[8].c_str()); |
3613 | 0 | } |
3614 | 0 | } |
3615 | | |
3616 | | // |
3617 | | // training |
3618 | | // |
3619 | | |
3620 | 0 | bool llama_opt_param_filter_all(const struct ggml_tensor * tensor, void * userdata) { |
3621 | 0 | GGML_UNUSED(tensor); |
3622 | 0 | GGML_UNUSED(userdata); |
3623 | 0 | return true; |
3624 | 0 | } |
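| | // Editor note: a filter decides per tensor whether it is trained. A minimal, |
| | // hypothetical example that only trains tensors whose name contains "output": |
| | // |
| | //     static bool filter_output_only(const struct ggml_tensor * t, void * /*userdata*/) { |
| | //         return strstr(ggml_get_name(t), "output") != nullptr; |
| | //     } |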
3625 | | |
3626 | 0 | void llama_opt_init(struct llama_context * ctx, struct llama_model * model, struct llama_opt_params lopt_params) { |
3627 | 0 | ctx->opt_init(model, lopt_params); |
3628 | 0 | } |
3629 | | |
3630 | | void llama_opt_epoch( |
3631 | | struct llama_context * ctx, |
3632 | | ggml_opt_dataset_t dataset, |
3633 | | ggml_opt_result_t result_train, |
3634 | | ggml_opt_result_t result_eval, |
3635 | | int64_t idata_split, |
3636 | | ggml_opt_epoch_callback callback_train, |
3637 | 0 | ggml_opt_epoch_callback callback_eval) { |
3638 | 0 | ctx->opt_epoch( |
3639 | 0 | dataset, |
3640 | 0 | result_train, |
3641 | 0 | result_eval, |
3642 | 0 | idata_split, |
3643 | 0 | callback_train, |
3644 | 0 | callback_eval); |
3645 | 0 | } |
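| | // Usage sketch (editor note): expected call order for the training API; `dataset` |
| | // is assumed to be a ggml_opt_dataset_t built elsewhere, and any llama_opt_params |
| | // fields not shown (e.g. the optimizer-parameter callback) are assumed to be set |
| | // as described in llama.h: |
| | // |
| | //     struct llama_opt_params lopt_params = {}; |
| | //     lopt_params.param_filter    = llama_opt_param_filter_all; // train every tensor |
| | //     lopt_params.param_filter_ud = nullptr; |
| | //     llama_opt_init(ctx, model, lopt_params); |
| | // |
| | //     ggml_opt_result_t result_train = ggml_opt_result_init(); |
| | //     ggml_opt_result_t result_eval  = ggml_opt_result_init(); |
| | //     const int64_t idata_split = ggml_opt_dataset_ndata(dataset)/2; // 50/50 split, illustrative |
| | //     llama_opt_epoch(ctx, dataset, result_train, result_eval, idata_split, nullptr, nullptr); |
| | //     ggml_opt_result_free(result_train); |
| | //     ggml_opt_result_free(result_eval); |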