/src/llama.cpp/src/llama-context.cpp
Line | Count | Source |
1 | | #include "llama-context.h" |
2 | | |
3 | | #include "llama-arch.h" |
4 | | #include "llama-impl.h" |
5 | | #include "llama-batch.h" |
6 | | #include "llama-io.h" |
7 | | #include "llama-memory.h" |
8 | | #include "llama-mmap.h" |
9 | | #include "llama-model.h" |
10 | | |
11 | | #include <cinttypes> |
12 | | #include <cmath> |
13 | | #include <cstring> |
14 | | #include <limits> |
15 | | #include <stdexcept> |
16 | | |
17 | | // |
18 | | // llama_context |
19 | | // |
20 | | |
21 | | llama_context::llama_context( |
22 | | const llama_model & model, |
23 | | llama_context_params params) : |
24 | 0 | model(model), |
25 | 0 | balloc(std::make_unique<llama_batch_allocr>(model.hparams.n_pos_per_embd())) { |
26 | | // TODO: warn when creating a llama_context with an awkward ctx size that is not a power of 2, |
27 | | // may need to be backend-dependent |
28 | 0 | LLAMA_LOG_INFO("%s: constructing llama_context\n", __func__); |
29 | |
|
30 | 0 | t_start_us = model.t_start_us; |
31 | 0 | t_load_us = model.t_load_us; |
32 | |
|
33 | 0 | const auto & hparams = model.hparams; |
34 | |
|
35 | 0 | cparams.n_seq_max = std::max(1u, params.n_seq_max); |
36 | 0 | if (cparams.n_seq_max > LLAMA_MAX_SEQ) { |
37 | 0 | throw std::runtime_error("n_seq_max must be <= " + std::to_string(LLAMA_MAX_SEQ)); |
38 | 0 | } |
39 | | |
40 | 0 | cparams.n_threads = params.n_threads; |
41 | 0 | cparams.n_threads_batch = params.n_threads_batch; |
42 | 0 | cparams.yarn_ext_factor = params.yarn_ext_factor >= 0.0f ? params.yarn_ext_factor : hparams.yarn_ext_factor; |
43 | 0 | cparams.yarn_attn_factor = params.yarn_attn_factor >= 0.0f ? params.yarn_attn_factor : hparams.yarn_attn_factor; |
44 | 0 | cparams.yarn_beta_fast = params.yarn_beta_fast >= 0.0f ? params.yarn_beta_fast : hparams.yarn_beta_fast; |
45 | 0 | cparams.yarn_beta_slow = params.yarn_beta_slow >= 0.0f ? params.yarn_beta_slow : hparams.yarn_beta_slow; |
46 | 0 | cparams.embeddings = params.embeddings; |
47 | 0 | cparams.offload_kqv = params.offload_kqv; |
48 | 0 | cparams.no_perf = params.no_perf; |
49 | 0 | cparams.pooling_type = params.pooling_type; |
50 | 0 | cparams.warmup = false; |
51 | |
|
52 | 0 | cparams.n_ctx = params.n_ctx == 0 ? hparams.n_ctx_train : params.n_ctx; |
53 | 0 | cparams.rope_freq_base = params.rope_freq_base == 0.0f ? hparams.rope_freq_base_train : params.rope_freq_base; |
54 | 0 | cparams.rope_freq_scale = params.rope_freq_scale == 0.0f ? hparams.rope_freq_scale_train : params.rope_freq_scale; |
55 | |
|
56 | 0 | cparams.n_ctx_orig_yarn = params.yarn_orig_ctx != 0 ? params.yarn_orig_ctx : |
57 | 0 | hparams.n_ctx_orig_yarn != 0 ? hparams.n_ctx_orig_yarn : |
58 | 0 | hparams.n_ctx_train; |
59 | |
|
60 | 0 | cparams.cb_eval = params.cb_eval; |
61 | 0 | cparams.cb_eval_user_data = params.cb_eval_user_data; |
62 | |
|
63 | 0 | auto rope_scaling_type = params.rope_scaling_type; |
64 | 0 | if (rope_scaling_type == LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED) { |
65 | 0 | rope_scaling_type = hparams.rope_scaling_type_train; |
66 | 0 | } |
67 | |
|
68 | 0 | if (rope_scaling_type == LLAMA_ROPE_SCALING_TYPE_NONE) { |
69 | 0 | cparams.rope_freq_scale = 1.0f; // never scale if scaling type is none |
70 | 0 | } |
71 | |
|
72 | 0 | if (cparams.yarn_ext_factor < 0.0f) { // negative indicates 'not set' |
73 | 0 | cparams.yarn_ext_factor = rope_scaling_type == LLAMA_ROPE_SCALING_TYPE_YARN ? 1.0f : 0.0f; |
74 | 0 | } |
75 | |
|
76 | 0 | if (cparams.yarn_ext_factor != 0) { |
77 | 0 | static auto get_mscale = [](float scale, float mscale) { |
78 | 0 | return scale <= 1.0f ? 1.0f : (0.1f * mscale * logf(scale) + 1.0f); |
79 | 0 | }; |
80 | |
|
81 | 0 | const float factor = 1.0f / cparams.rope_freq_scale; |
82 | | |
83 | | // ref: https://github.com/huggingface/transformers/blob/6d00f6b0a5679c36510f203e4226e36f517c3032/src/transformers/modeling_rope_utils.py#L336-L348 |
84 | 0 | if (hparams.rope_yarn_log_mul != 0.0f) { |
85 | | // note: here we assume `mscale == 1.0f` |
86 | | // TODO: start reading the actual value of mscale and handle the case where it is not 1.0f |
87 | 0 | float mscale = 1.0f; |
88 | 0 | const float mscale_all_dims = hparams.rope_yarn_log_mul; |
89 | | |
90 | | // [TAG_DEEPSEEK2_YARN_LOG_MUL_FIX] |
91 | | // special-case DEEPSEEK v2: |
92 | | // https://huggingface.co/deepseek-ai/DeepSeek-V2-Lite-Chat/blob/main/config.json#L42-L43 |
93 | 0 | if (model.arch == LLM_ARCH_DEEPSEEK2 && mscale_all_dims != 1.0f) { |
94 | 0 | mscale = mscale_all_dims; |
95 | 0 | } |
96 | |
|
97 | 0 | cparams.yarn_attn_factor = get_mscale(factor, mscale) / get_mscale(factor, mscale_all_dims); |
98 | |
|
99 | 0 | LLAMA_LOG_WARN("%s: setting new yarn_attn_factor = %.4f (mscale == %.1f, mscale_all_dim = %.1f)\n", |
100 | 0 | __func__, cparams.yarn_attn_factor, mscale, mscale_all_dims); |
101 | 0 | } else { |
102 | 0 | cparams.yarn_attn_factor = get_mscale(factor, 1.0f); |
103 | 0 | } |
104 | | |
105 | | // when YARN is applied with yarn_ext_factor != 0.0f, we need to cancel this factor: |
106 | | // https://github.com/ggml-org/llama.cpp/blob/a81a569577cc38b32558958b048228150be63eae/ggml/src/ggml-cpu/ops.cpp#L5541-L5544 |
107 | | // |
108 | | // ref: https://github.com/ggml-org/llama.cpp/discussions/7416 |
109 | | // https://github.com/ggml-org/llama.cpp/pull/17945 |
110 | 0 | cparams.yarn_attn_factor *= 1.0f / (1.0f + 0.1f * logf(factor)); |
111 | 0 | } |
112 | |
|
113 | 0 | cparams.yarn_attn_factor *= hparams.rope_attn_factor; |
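The YaRN branch above is easier to follow as a formula. The following is a sketch of the math implied by the `get_mscale` lambda for the `yarn_ext_factor != 0` path, assuming `mscale == 1.0f` as the comment notes and writing s = 1 / rope_freq_scale for the scaling factor:

```latex
% mscale as implemented by the get_mscale lambda
\mathrm{mscale}(s, m) = \begin{cases} 1 & s \le 1 \\ 0.1\, m \ln s + 1 & s > 1 \end{cases}

% with s = 1 / rope_freq_scale, the yarn_ext_factor != 0 path computes
\text{yarn\_attn\_factor}
  = \frac{\mathrm{mscale}(s,\ \text{mscale})}{\mathrm{mscale}(s,\ \text{mscale\_all\_dims})}
    \cdot \frac{1}{1 + 0.1 \ln s}
    \cdot \text{rope\_attn\_factor}

% when rope_yarn_log_mul == 0 the ratio degenerates to mscale(s, 1), and the
% 1 / (1 + 0.1 ln s) term cancels the factor that ggml's rope op applies internally
```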
114 | |
|
115 | 0 | if (cparams.pooling_type == LLAMA_POOLING_TYPE_UNSPECIFIED) { |
116 | 0 | if (hparams.pooling_type == LLAMA_POOLING_TYPE_UNSPECIFIED) { |
117 | 0 | cparams.pooling_type = LLAMA_POOLING_TYPE_NONE; |
118 | 0 | } else { |
119 | 0 | cparams.pooling_type = hparams.pooling_type; |
120 | 0 | } |
121 | 0 | } |
122 | |
|
123 | 0 | if (params.attention_type == LLAMA_ATTENTION_TYPE_UNSPECIFIED) { |
124 | 0 | cparams.causal_attn = hparams.causal_attn; |
125 | 0 | } else { |
126 | 0 | cparams.causal_attn = params.attention_type == LLAMA_ATTENTION_TYPE_CAUSAL; |
127 | 0 | } |
128 | |
|
129 | 0 | cparams.flash_attn = params.flash_attn_type != LLAMA_FLASH_ATTN_TYPE_DISABLED; |
130 | | |
131 | | // with causal attention, the batch size is limited by the context size |
132 | 0 | cparams.n_batch = cparams.causal_attn ? std::min(cparams.n_ctx, params.n_batch) : params.n_batch; |
133 | |
|
134 | 0 | cparams.n_ubatch = std::min(cparams.n_batch, params.n_ubatch == 0 ? params.n_batch : params.n_ubatch); |
135 | |
|
136 | 0 | cparams.op_offload = params.op_offload; |
137 | 0 | cparams.kv_unified = params.kv_unified; |
138 | |
|
139 | 0 | { |
140 | 0 | const char * LLAMA_GRAPH_REUSE_DISABLE = getenv("LLAMA_GRAPH_REUSE_DISABLE"); |
141 | 0 | graph_reuse_disable = LLAMA_GRAPH_REUSE_DISABLE ? (atoi(LLAMA_GRAPH_REUSE_DISABLE) != 0) : graph_reuse_disable; |
142 | |
|
143 | 0 | if (graph_reuse_disable) { |
144 | 0 | LLAMA_LOG_WARN("%s: graph reuse disabled\n", __func__); |
145 | 0 | } |
146 | 0 | } |
147 | | |
148 | | // ref: https://github.com/ggml-org/llama.cpp/pull/17046#discussion_r2503085732 |
149 | 0 | cparams.n_ctx = GGML_PAD(cparams.n_ctx, 256); |
150 | |
|
151 | 0 | if (cparams.kv_unified) { |
152 | 0 | cparams.n_ctx_seq = cparams.n_ctx; |
153 | 0 | } else { |
154 | 0 | cparams.n_ctx_seq = cparams.n_ctx / cparams.n_seq_max; |
155 | 0 | cparams.n_ctx_seq = GGML_PAD(cparams.n_ctx_seq, 256); |
156 | |
|
157 | 0 | if (cparams.n_ctx_seq == 0) { |
158 | 0 | throw std::runtime_error("n_ctx_seq == 0"); |
159 | 0 | } |
160 | | |
161 | 0 | if (cparams.n_ctx != cparams.n_ctx_seq * cparams.n_seq_max) { |
162 | 0 | cparams.n_ctx = cparams.n_ctx_seq * cparams.n_seq_max; |
163 | 0 | LLAMA_LOG_WARN("%s: n_ctx is not divisible by n_seq_max - rounding down to %u\n", __func__, cparams.n_ctx); |
164 | 0 | } |
165 | 0 | } |
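A worked example of the rounding above, using hypothetical values `n_ctx = 4096`, `n_seq_max = 3` and a non-unified KV cache (`GGML_PAD(x, 256)` rounds x up to the next multiple of 256):

```latex
n_{\text{ctx}}      = \mathrm{GGML\_PAD}(4096, 256) = 4096 \\
n_{\text{ctx\_seq}} = \mathrm{GGML\_PAD}(\lfloor 4096 / 3 \rfloor, 256) = \mathrm{GGML\_PAD}(1365, 256) = 1536 \\
n_{\text{ctx}} \leftarrow n_{\text{ctx\_seq}} \cdot n_{\text{seq\_max}} = 1536 \cdot 3 = 4608
```

Each sequence thus gets a 256-aligned slice, and the total context is adjusted (with a warning) so that `n_ctx == n_ctx_seq * n_seq_max`.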
166 | | |
167 | 0 | LLAMA_LOG_INFO("%s: n_seq_max = %u\n", __func__, cparams.n_seq_max); |
168 | 0 | LLAMA_LOG_INFO("%s: n_ctx = %u\n", __func__, cparams.n_ctx); |
169 | 0 | LLAMA_LOG_INFO("%s: n_ctx_seq = %u\n", __func__, cparams.n_ctx_seq); |
170 | 0 | LLAMA_LOG_INFO("%s: n_batch = %u\n", __func__, cparams.n_batch); |
171 | 0 | LLAMA_LOG_INFO("%s: n_ubatch = %u\n", __func__, cparams.n_ubatch); |
172 | 0 | LLAMA_LOG_INFO("%s: causal_attn = %d\n", __func__, cparams.causal_attn); |
173 | 0 | LLAMA_LOG_INFO("%s: flash_attn = %s\n", __func__, llama_flash_attn_type_name(params.flash_attn_type)); |
174 | 0 | LLAMA_LOG_INFO("%s: kv_unified = %s\n", __func__, cparams.kv_unified ? "true" : "false"); |
175 | 0 | LLAMA_LOG_INFO("%s: freq_base = %.1f\n", __func__, cparams.rope_freq_base); |
176 | 0 | LLAMA_LOG_INFO("%s: freq_scale = %g\n", __func__, cparams.rope_freq_scale); |
177 | |
|
178 | 0 | if (cparams.n_ctx_seq < hparams.n_ctx_train) { |
179 | 0 | LLAMA_LOG_WARN("%s: n_ctx_seq (%u) < n_ctx_train (%u) -- the full capacity of the model will not be utilized\n", |
180 | 0 | __func__, cparams.n_ctx_seq, hparams.n_ctx_train); |
181 | 0 | } |
182 | |
|
183 | 0 | if (cparams.n_ctx_seq > hparams.n_ctx_train) { |
184 | 0 | LLAMA_LOG_WARN("%s: n_ctx_seq (%u) > n_ctx_train (%u) -- possible training context overflow\n", |
185 | 0 | __func__, cparams.n_ctx_seq, hparams.n_ctx_train); |
186 | 0 | } |
187 | |
|
188 | 0 | if (!hparams.vocab_only) { |
189 | | // GPU backends |
190 | 0 | for (auto * dev : model.devices) { |
191 | 0 | ggml_backend_t backend = ggml_backend_dev_init(dev, nullptr); |
192 | 0 | if (backend == nullptr) { |
193 | 0 | throw std::runtime_error(format("failed to initialize %s backend", ggml_backend_dev_name(dev))); |
194 | 0 | } |
195 | 0 | backends.emplace_back(backend); |
196 | 0 | } |
197 | | |
198 | | // add ACCEL backends (such as BLAS) |
199 | 0 | for (size_t i = 0; i < ggml_backend_dev_count(); ++i) { |
200 | 0 | ggml_backend_dev_t dev = ggml_backend_dev_get(i); |
201 | 0 | if (ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_ACCEL) { |
202 | 0 | ggml_backend_t backend = ggml_backend_dev_init(dev, nullptr); |
203 | 0 | if (backend == nullptr) { |
204 | 0 | throw std::runtime_error(format("failed to initialize %s backend", ggml_backend_dev_name(dev))); |
205 | 0 | } |
206 | 0 | backends.emplace_back(backend); |
207 | 0 | } |
208 | 0 | } |
209 | | |
210 | | // add CPU backend |
211 | 0 | backend_cpu = ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_CPU, nullptr); |
212 | 0 | if (backend_cpu == nullptr) { |
213 | 0 | throw std::runtime_error("failed to initialize CPU backend"); |
214 | 0 | } |
215 | 0 | backends.emplace_back(backend_cpu); |
216 | | |
217 | | // create a list of the set_n_threads functions in the backends |
218 | 0 | for (auto & backend : backends) { |
219 | 0 | ggml_backend_dev_t dev = ggml_backend_get_device(backend.get()); |
220 | 0 | ggml_backend_reg_t reg = dev ? ggml_backend_dev_backend_reg(dev) : nullptr; |
221 | 0 | if (reg) { |
222 | 0 | auto ggml_backend_set_n_threads_fn = (ggml_backend_set_n_threads_t) ggml_backend_reg_get_proc_address(reg, "ggml_backend_set_n_threads"); |
223 | 0 | if (ggml_backend_set_n_threads_fn) { |
224 | 0 | set_n_threads_fns.emplace_back(backend.get(), ggml_backend_set_n_threads_fn); |
225 | 0 | } |
226 | 0 | } |
227 | 0 | } |
228 | |
|
229 | 0 | llama_set_abort_callback(this, params.abort_callback, params.abort_callback_data); |
230 | | |
231 | | // graph outputs buffer |
232 | 0 | { |
233 | | // resized during inference when a batch uses more outputs |
234 | 0 | if (output_reserve(params.n_seq_max) < params.n_seq_max) { |
235 | 0 | throw std::runtime_error("failed to reserve initial output buffer"); |
236 | 0 | } |
237 | | |
238 | 0 | LLAMA_LOG_INFO("%s: %10s output buffer size = %8.2f MiB\n", __func__, |
239 | 0 | ggml_backend_buffer_name (buf_output.get()), |
240 | 0 | ggml_backend_buffer_get_size(buf_output.get()) / 1024.0 / 1024.0); |
241 | 0 | } |
242 | 0 | } |
243 | | |
244 | | // init the memory module |
245 | 0 | if (!hparams.vocab_only) { |
246 | 0 | llama_memory_params params_mem = { |
247 | 0 | /*.type_k =*/ params.type_k, |
248 | 0 | /*.type_v =*/ params.type_v, |
249 | 0 | /*.swa_full =*/ params.swa_full, |
250 | 0 | }; |
251 | |
|
252 | 0 | memory.reset(model.create_memory(params_mem, cparams)); |
253 | 0 | } |
254 | | |
255 | | // init backends |
256 | 0 | if (!hparams.vocab_only) { |
257 | 0 | LLAMA_LOG_DEBUG("%s: enumerating backends\n", __func__); |
258 | |
|
259 | 0 | backend_buft.clear(); |
260 | 0 | backend_ptrs.clear(); |
261 | 0 | backend_buf_exp_size.clear(); |
262 | |
|
263 | 0 | for (auto & backend : backends) { |
264 | 0 | auto * buft = ggml_backend_get_default_buffer_type(backend.get()); |
265 | 0 | auto backend_type = ggml_backend_dev_type(ggml_backend_get_device(backend.get())); |
266 | |
|
267 | 0 | if (backend_type == GGML_BACKEND_DEVICE_TYPE_CPU && !model.devices.empty()) { |
268 | | // for the CPU backend, use the host buffer type of the first device for faster transfer of the intermediate state |
269 | 0 | auto * dev = model.devices[0]; |
270 | 0 | auto * host_buft = ggml_backend_dev_host_buffer_type(dev); |
271 | 0 | if (host_buft) { |
272 | 0 | buft = host_buft; |
273 | 0 | } |
274 | 0 | } |
275 | |
|
276 | 0 | backend_buft.push_back(buft); |
277 | 0 | backend_ptrs.push_back(backend.get()); |
278 | 0 | backend_buf_exp_size.push_back(0); |
279 | 0 | } |
280 | |
|
281 | 0 | LLAMA_LOG_DEBUG("%s: backend_ptrs.size() = %zu\n", __func__, backend_ptrs.size()); |
282 | |
|
283 | 0 | const uint32_t n_seqs = cparams.n_seq_max; |
284 | 0 | const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch); |
285 | |
|
286 | 0 | const size_t max_nodes = this->graph_max_nodes(n_tokens); |
287 | |
|
288 | 0 | LLAMA_LOG_DEBUG("%s: max_nodes = %zu\n", __func__, max_nodes); |
289 | |
|
290 | 0 | gf_res_prev.reset(new llm_graph_result(max_nodes)); |
291 | 0 | gf_res_reserve.reset(new llm_graph_result(max_nodes)); |
292 | | |
293 | | // TODO: move these checks to ggml_backend_sched |
294 | | // enabling pipeline parallelism in the scheduler increases memory usage, so it is only done when necessary |
295 | 0 | bool pipeline_parallel = |
296 | 0 | model.n_devices() > 1 && |
297 | 0 | model.n_gpu_layers() > model.hparams.n_layer && |
298 | 0 | model.split_mode() == LLAMA_SPLIT_MODE_LAYER && |
299 | 0 | cparams.offload_kqv && |
300 | 0 | !model.has_tensor_overrides(); |
301 | | |
302 | | // pipeline parallelism requires support for async compute and events in all devices |
303 | 0 | if (pipeline_parallel) { |
304 | 0 | for (auto & backend : backends) { |
305 | 0 | auto dev_type = ggml_backend_dev_type(ggml_backend_get_device(backend.get())); |
306 | 0 | if (dev_type == GGML_BACKEND_DEVICE_TYPE_CPU) { |
307 | | // ignore CPU backend |
308 | 0 | continue; |
309 | 0 | } |
310 | 0 | auto * dev = ggml_backend_get_device(backend.get()); |
311 | 0 | ggml_backend_dev_props props; |
312 | 0 | ggml_backend_dev_get_props(dev, &props); |
313 | 0 | if (!props.caps.async || !props.caps.events) { |
314 | | // device does not support async compute or events |
315 | 0 | pipeline_parallel = false; |
316 | 0 | break; |
317 | 0 | } |
318 | 0 | } |
319 | 0 | } |
320 | |
|
321 | 0 | sched.reset(ggml_backend_sched_new(backend_ptrs.data(), backend_buft.data(), backend_ptrs.size(), max_nodes, pipeline_parallel, cparams.op_offload)); |
322 | |
|
323 | 0 | if (pipeline_parallel) { |
324 | 0 | LLAMA_LOG_INFO("%s: pipeline parallelism enabled (n_copies=%d)\n", __func__, ggml_backend_sched_get_n_copies(sched.get())); |
325 | 0 | } |
326 | |
|
327 | 0 | llama_memory_context_ptr mctx; |
328 | 0 | if (memory) { |
329 | 0 | LLAMA_LOG_DEBUG("%s: reserving full memory module\n", __func__); |
330 | 0 | mctx = memory->init_full(); |
331 | 0 | if (!mctx) { |
332 | 0 | throw std::runtime_error("failed to initialize memory module"); |
333 | 0 | } |
334 | 0 | } |
335 | | |
336 | 0 | cross.v_embd.clear(); |
337 | | |
338 | | // avoid reserving graphs with zero outputs - assume one output per sequence |
339 | 0 | n_outputs = n_seqs; |
340 | |
|
341 | 0 | LLAMA_LOG_DEBUG("%s: worst-case: n_tokens = %d, n_seqs = %d, n_outputs = %d\n", __func__, n_tokens, n_seqs, n_outputs); |
342 | | |
343 | | // resolve automatic Flash Attention use |
344 | 0 | if (params.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_AUTO) { |
345 | 0 | auto * gf = graph_reserve(1, n_seqs, n_outputs, mctx.get(), true); |
346 | 0 | if (!gf) { |
347 | 0 | throw std::runtime_error("failed to split graph for Flash Attention check"); |
348 | 0 | } |
349 | | |
350 | 0 | const size_t prefix_len = strlen(LLAMA_TENSOR_NAME_FATTN) + 1; |
351 | 0 | bool fa_device_mismatch = false; |
352 | 0 | for (int i = 0; i < ggml_graph_n_nodes(gf); i++) { |
353 | 0 | ggml_tensor * n = ggml_graph_node(gf, i); |
354 | 0 | if (n->op != GGML_OP_FLASH_ATTN_EXT) { |
355 | 0 | continue; |
356 | 0 | } |
357 | 0 | ggml_backend_dev_t device_fa = ggml_backend_get_device( |
358 | 0 | ggml_backend_sched_get_tensor_backend(sched.get(), n)); |
359 | | |
360 | | // TODO: instead of the tensor names, use a map to keep track of which (FA) tensors belong to which layer |
361 | 0 | GGML_ASSERT(strncmp(n->name, LLAMA_TENSOR_NAME_FATTN "-", prefix_len) == 0); |
362 | 0 | const int il = std::stoi(n->name + prefix_len); |
363 | 0 | ggml_backend_dev_t device_kv = model.dev_layer(il); |
364 | 0 | if (device_fa != device_kv) { |
365 | 0 | LLAMA_LOG_WARN("%s: layer %d is assigned to device %s but the Flash Attention tensor " |
366 | 0 | "is assigned to device %s (usually due to missing support)\n", |
367 | 0 | __func__, il, ggml_backend_dev_name(device_kv), ggml_backend_dev_name(device_fa)); |
368 | | // FIXME: fa_device_mismatch logic is wrong for --no-kv-offload, but this is broken anyway |
369 | 0 | fa_device_mismatch = true; |
370 | 0 | break; |
371 | 0 | } |
372 | 0 | } |
373 | 0 | if (fa_device_mismatch) { |
374 | 0 | cparams.flash_attn = false; |
375 | 0 | LLAMA_LOG_WARN("%s: Flash Attention was auto, set to disabled\n", __func__); |
376 | 0 | if (ggml_is_quantized(params.type_v)) { |
377 | 0 | throw std::runtime_error("quantized V cache was requested, but this requires Flash Attention"); |
378 | 0 | } |
379 | 0 | } else { |
380 | 0 | cparams.flash_attn = true; |
381 | 0 | LLAMA_LOG_INFO("%s: Flash Attention was auto, set to enabled\n", __func__); |
382 | 0 | } |
383 | 0 | } |
384 | | |
385 | | // reserve worst-case graph |
386 | 0 | int n_splits_pp = -1; |
387 | 0 | int n_nodes_pp = -1; |
388 | |
|
389 | 0 | int n_splits_tg = -1; |
390 | 0 | int n_nodes_tg = -1; |
391 | | |
392 | | // reserve pp (prompt processing) graph first so that buffers are only allocated once |
393 | 0 | { |
394 | 0 | auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get(), |
395 | 0 | model.hparams.no_alloc, model.hparams.no_alloc ? backend_buf_exp_size.data() : nullptr); |
396 | 0 | if (!gf) { |
397 | 0 | if (pipeline_parallel) { |
398 | 0 | LLAMA_LOG_WARN("%s: compute buffer allocation failed, retrying without pipeline parallelism\n", __func__); |
399 | 0 | sched.reset(ggml_backend_sched_new(backend_ptrs.data(), backend_buft.data(), backend_ptrs.size(), max_nodes, false, cparams.op_offload)); |
400 | 0 | gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get()); |
401 | 0 | } |
402 | 0 | if (!gf) { |
403 | 0 | throw std::runtime_error("failed to allocate compute pp buffers"); |
404 | 0 | } |
405 | 0 | } |
406 | | |
407 | 0 | n_splits_pp = ggml_backend_sched_get_n_splits(sched.get()); |
408 | 0 | n_nodes_pp = ggml_graph_n_nodes(gf); |
409 | 0 | } |
410 | | |
411 | | // reserve with tg (token generation) graph to get the number of splits and nodes |
412 | 0 | { |
413 | 0 | auto * gf = graph_reserve(n_seqs, n_seqs, n_seqs, mctx.get(), model.hparams.no_alloc); |
414 | 0 | if (!gf) { |
415 | 0 | throw std::runtime_error("failed to allocate compute tg buffers"); |
416 | 0 | } |
417 | | |
418 | 0 | n_splits_tg = ggml_backend_sched_get_n_splits(sched.get()); |
419 | 0 | n_nodes_tg = ggml_graph_n_nodes(gf); |
420 | 0 | } |
421 | | |
422 | | // reserve again with pp graph to avoid ggml-alloc reallocations during inference |
423 | 0 | { |
424 | | // TODO: not sure if the following graph would be the worst case for multi-stream KV caches: |
425 | | // |
426 | | // auto * gf = graph_reserve(n_tokens, 1, n_tokens, mctx.get()); |
427 | | // |
428 | 0 | auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get(), model.hparams.no_alloc); |
429 | 0 | if (!gf) { |
430 | 0 | throw std::runtime_error("failed to allocate compute pp buffers"); |
431 | 0 | } |
432 | 0 | } |
433 | | |
434 | 0 | for (size_t i = 0; i < backend_ptrs.size(); ++i) { |
435 | 0 | ggml_backend_t backend = backend_ptrs[i]; |
436 | 0 | ggml_backend_buffer_type_t buft = backend_buft[i]; |
437 | 0 | if (!model.hparams.no_alloc) { |
438 | 0 | backend_buf_exp_size[i] = ggml_backend_sched_get_buffer_size(sched.get(), backend); |
439 | 0 | } |
440 | 0 | if (backend_buf_exp_size[i] > 1) { |
441 | 0 | LLAMA_LOG_INFO("%s: %10s compute buffer size = %8.2f MiB\n", __func__, |
442 | 0 | ggml_backend_buft_name(buft), |
443 | 0 | backend_buf_exp_size[i] / 1024.0 / 1024.0); |
444 | 0 | } |
445 | 0 | } |
446 | |
|
447 | 0 | if (n_nodes_pp == n_nodes_tg) { |
448 | 0 | LLAMA_LOG_INFO("%s: graph nodes = %d\n", __func__, n_nodes_pp); |
449 | 0 | } else { |
450 | 0 | LLAMA_LOG_INFO("%s: graph nodes = %d (with bs=%d), %d (with bs=1)\n", __func__, n_nodes_pp, n_tokens, n_nodes_tg); |
451 | 0 | } |
452 | |
|
453 | 0 | if (n_splits_pp == n_splits_tg) { |
454 | 0 | LLAMA_LOG_INFO("%s: graph splits = %d\n", __func__, n_splits_pp); |
455 | 0 | } else { |
456 | 0 | LLAMA_LOG_INFO("%s: graph splits = %d (with bs=%d), %d (with bs=1)\n", __func__, n_splits_pp, n_tokens, n_splits_tg); |
457 | 0 | } |
458 | 0 | } |
459 | 0 | } |
460 | | |
461 | 0 | llama_context::~llama_context() { |
462 | 0 | if (!model.hparams.no_alloc) { |
463 | 0 | for (size_t i = 0; i < backend_ptrs.size(); ++i) { |
464 | 0 | ggml_backend_t backend = backend_ptrs[i]; |
465 | 0 | ggml_backend_buffer_type_t buft = backend_buft[i]; |
466 | |
|
467 | 0 | const size_t size_exp = backend_buf_exp_size[i]; |
468 | 0 | const size_t size_act = ggml_backend_sched_get_buffer_size(sched.get(), backend); |
469 | 0 | if (size_exp == size_act) { |
470 | 0 | LLAMA_LOG_DEBUG("%s: %10s compute buffer size is %8.4f MiB, matches expectation of %8.4f MiB\n", |
471 | 0 | __func__, ggml_backend_buft_name(buft), size_act / (1024.0*1024.0), size_exp / (1024.0*1024.0)); |
472 | 0 | } else { |
473 | 0 | LLAMA_LOG_WARN("%s: %10s compute buffer size of %8.4f MiB, does not match expectation of %8.4f MiB\n", |
474 | 0 | __func__, ggml_backend_buft_name(buft), size_act / (1024.0*1024.0), size_exp / (1024.0*1024.0)); |
475 | 0 | } |
476 | 0 | } |
477 | 0 | } |
478 | 0 | ggml_opt_free(opt_ctx); |
479 | 0 | } |
480 | | |
481 | 0 | void llama_context::synchronize() { |
482 | 0 | ggml_backend_sched_synchronize(sched.get()); |
483 | | |
484 | | // FIXME: if multiple single tokens are evaluated without a synchronization, |
485 | | // the stats will be added to the prompt evaluation stats |
486 | | // this should only happen when using batch size 1 to evaluate a batch |
487 | | |
488 | | // add the evaluation to the stats |
489 | 0 | if (n_queued_tokens == 1) { |
490 | 0 | if (!cparams.no_perf) { |
491 | 0 | t_eval_us += ggml_time_us() - t_compute_start_us; |
492 | 0 | } |
493 | 0 | n_eval++; |
494 | 0 | } else if (n_queued_tokens > 1) { |
495 | 0 | if (!cparams.no_perf) { |
496 | 0 | t_p_eval_us += ggml_time_us() - t_compute_start_us; |
497 | 0 | } |
498 | 0 | n_p_eval += n_queued_tokens; |
499 | 0 | } |
500 | | |
501 | | // get a more accurate load time, upon first eval |
502 | 0 | if (n_queued_tokens > 0 && !has_evaluated_once) { |
503 | 0 | t_load_us = ggml_time_us() - t_start_us; |
504 | 0 | has_evaluated_once = true; |
505 | 0 | } |
506 | |
|
507 | 0 | n_queued_tokens = 0; |
508 | 0 | t_compute_start_us = 0; |
509 | 0 | } |
510 | | |
511 | 0 | const llama_model & llama_context::get_model() const { |
512 | 0 | return model; |
513 | 0 | } |
514 | | |
515 | 0 | const llama_cparams & llama_context::get_cparams() const { |
516 | 0 | return cparams; |
517 | 0 | } |
518 | | |
519 | 0 | ggml_backend_sched_t llama_context::get_sched() const { |
520 | 0 | return sched.get(); |
521 | 0 | } |
522 | | |
523 | 0 | uint32_t llama_context::n_ctx() const { |
524 | 0 | return cparams.n_ctx; |
525 | 0 | } |
526 | | |
527 | 0 | uint32_t llama_context::n_ctx_seq() const { |
528 | 0 | return cparams.n_ctx_seq; |
529 | 0 | } |
530 | | |
531 | 0 | uint32_t llama_context::n_batch() const { |
532 | 0 | return cparams.n_batch; |
533 | 0 | } |
534 | | |
535 | 0 | uint32_t llama_context::n_ubatch() const { |
536 | 0 | return cparams.n_ubatch; |
537 | 0 | } |
538 | | |
539 | 0 | uint32_t llama_context::n_seq_max() const { |
540 | 0 | return cparams.n_seq_max; |
541 | 0 | } |
542 | | |
543 | 0 | uint32_t llama_context::n_threads() const { |
544 | 0 | return cparams.n_threads; |
545 | 0 | } |
546 | | |
547 | 0 | uint32_t llama_context::n_threads_batch() const { |
548 | 0 | return cparams.n_threads_batch; |
549 | 0 | } |
550 | | |
551 | 0 | llama_memory_t llama_context::get_memory() const { |
552 | 0 | return memory.get(); |
553 | 0 | } |
554 | | |
555 | 0 | bool llama_context::memory_update(bool optimize) { |
556 | 0 | if (!memory) { |
557 | 0 | return false; |
558 | 0 | } |
559 | | |
560 | 0 | { |
561 | 0 | const auto mctx = memory->init_update(this, optimize); |
562 | 0 | switch (mctx->get_status()) { |
563 | 0 | case LLAMA_MEMORY_STATUS_SUCCESS: |
564 | 0 | { |
565 | | // noop |
566 | 0 | } break; |
567 | 0 | case LLAMA_MEMORY_STATUS_NO_UPDATE: |
568 | 0 | { |
569 | | // no updates need to be performed |
570 | 0 | return false; |
571 | 0 | } |
572 | 0 | case LLAMA_MEMORY_STATUS_FAILED_PREPARE: |
573 | 0 | case LLAMA_MEMORY_STATUS_FAILED_COMPUTE: |
574 | 0 | { |
575 | 0 | LLAMA_LOG_ERROR("%s: failed to prepare memory update\n", __func__); |
576 | 0 | return false; |
577 | 0 | } |
578 | 0 | } |
579 | | |
580 | | // reset the previous graph result to make sure that it won't be reused |
581 | | // TODO: change mctx->apply() to return whether a graph reserve is needed |
582 | | // reset the graph result only if the memory module did reset the scheduler |
583 | 0 | gf_res_prev->reset(); |
584 | |
|
585 | 0 | if (!mctx->apply()) { |
586 | 0 | LLAMA_LOG_ERROR("%s: failed to apply memory update\n", __func__); |
587 | 0 | } |
588 | 0 | } |
589 | | |
590 | | // if the memory module did any computation, we have to reserve a new worst-case graph |
591 | 0 | { |
592 | 0 | const auto mctx = memory->init_full(); |
593 | 0 | if (!mctx) { |
594 | 0 | throw std::runtime_error("failed to initialize memory context"); |
595 | 0 | } |
596 | | |
597 | 0 | const uint32_t n_seqs = cparams.n_seq_max; |
598 | 0 | const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch); |
599 | |
|
600 | 0 | auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get()); |
601 | 0 | if (!gf) { |
602 | 0 | LLAMA_LOG_ERROR("%s: failed to reserve graph after the memory update\n", __func__); |
603 | 0 | } |
604 | 0 | } |
605 | | |
606 | 0 | return true; |
607 | 0 | } |
608 | | |
609 | 0 | enum llama_pooling_type llama_context::pooling_type() const { |
610 | 0 | return cparams.pooling_type; |
611 | 0 | } |
612 | | |
613 | 0 | float * llama_context::get_logits() { |
614 | 0 | output_reorder(); |
615 | |
|
616 | 0 | return logits; |
617 | 0 | } |
618 | | |
619 | 0 | float * llama_context::get_logits_ith(int32_t i) { |
620 | 0 | int64_t j = -1; |
621 | |
|
622 | 0 | output_reorder(); |
623 | |
|
624 | 0 | try { |
625 | 0 | if (logits == nullptr) { |
626 | 0 | throw std::runtime_error("no logits"); |
627 | 0 | } |
628 | | |
629 | 0 | if (i < 0) { |
630 | 0 | j = n_outputs + i; |
631 | 0 | if (j < 0) { |
632 | 0 | throw std::runtime_error(format("negative index out of range [0, %d)", n_outputs)); |
633 | 0 | } |
634 | 0 | } else if ((size_t) i >= output_ids.size()) { |
635 | 0 | throw std::runtime_error(format("out of range [0, %zu)", output_ids.size())); |
636 | 0 | } else { |
637 | 0 | j = output_ids[i]; |
638 | 0 | } |
639 | | |
640 | 0 | if (j < 0) { |
641 | 0 | throw std::runtime_error(format("batch.logits[%d] != true", i)); |
642 | 0 | } |
643 | 0 | if (j >= n_outputs) { |
644 | | // This should not happen |
645 | 0 | throw std::runtime_error(format("corrupt output buffer (j=%" PRId64 ", n_outputs=%d)", j, n_outputs)); |
646 | 0 | } |
647 | | |
648 | 0 | return logits + j*model.vocab.n_tokens(); |
649 | 0 | } catch (const std::exception & err) { |
650 | 0 | LLAMA_LOG_ERROR("%s: invalid logits id %d, reason: %s\n", __func__, i, err.what()); |
651 | | #ifndef NDEBUG |
652 | | GGML_ABORT("fatal error"); |
653 | | #else |
654 | 0 | return nullptr; |
655 | 0 | #endif |
656 | 0 | } |
657 | 0 | } |
658 | | |
659 | 0 | float * llama_context::get_embeddings() { |
660 | 0 | output_reorder(); |
661 | |
|
662 | 0 | return embd; |
663 | 0 | } |
664 | | |
665 | 0 | float * llama_context::get_embeddings_ith(int32_t i) { |
666 | 0 | int64_t j = -1; |
667 | |
|
668 | 0 | output_reorder(); |
669 | |
|
670 | 0 | try { |
671 | 0 | if (embd == nullptr) { |
672 | 0 | throw std::runtime_error("no embeddings"); |
673 | 0 | } |
674 | | |
675 | 0 | if (i < 0) { |
676 | 0 | j = n_outputs + i; |
677 | 0 | if (j < 0) { |
678 | 0 | throw std::runtime_error(format("negative index out of range [0, %d)", n_outputs)); |
679 | 0 | } |
680 | 0 | } else if ((size_t) i >= output_ids.size()) { |
681 | 0 | throw std::runtime_error(format("out of range [0, %zu)", output_ids.size())); |
682 | 0 | } else { |
683 | 0 | j = output_ids[i]; |
684 | 0 | } |
685 | | |
686 | 0 | if (j < 0) { |
687 | 0 | throw std::runtime_error(format("batch.logits[%d] != true", i)); |
688 | 0 | } |
689 | 0 | if (j >= n_outputs) { |
690 | | // This should not happen |
691 | 0 | throw std::runtime_error(format("corrupt output buffer (j=%" PRId64 ", n_outputs=%d)", j, n_outputs)); |
692 | 0 | } |
693 | | |
694 | 0 | return embd + j*model.hparams.n_embd; |
695 | 0 | } catch (const std::exception & err) { |
696 | 0 | LLAMA_LOG_ERROR("%s: invalid embeddings id %d, reason: %s\n", __func__, i, err.what()); |
697 | | #ifndef NDEBUG |
698 | | GGML_ABORT("fatal error"); |
699 | | #else |
700 | 0 | return nullptr; |
701 | 0 | #endif |
702 | 0 | } |
703 | 0 | } |
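The index mapping shared by `get_logits_ith()` and `get_embeddings_ith()` can be restated as a small standalone helper. This is a simplified illustration only, not part of llama.cpp, and it drops the exception/logging path:

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Translate a user-facing batch index i into a row j of the packed output buffer.
// Returns -1 when no output was computed for that index.
static int64_t map_output_index(int32_t i, int64_t n_outputs, const std::vector<int32_t> & output_ids) {
    if (i < 0) {
        const int64_t j = n_outputs + i;       // e.g. i == -1 -> last output row
        return j >= 0 ? j : -1;
    }
    if ((size_t) i >= output_ids.size()) {
        return -1;                             // index past the batch size
    }
    const int64_t j = output_ids[i];           // -1 if batch.logits[i] was not set
    return (j >= 0 && j < n_outputs) ? j : -1; // reject corrupt/out-of-range entries
}
```

A valid j is then scaled by `n_vocab` (logits) or `n_embd` (embeddings) to address the row.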
704 | | |
705 | 0 | float * llama_context::get_embeddings_seq(llama_seq_id seq_id) { |
706 | 0 | auto it = embd_seq.find(seq_id); |
707 | 0 | if (it == embd_seq.end()) { |
708 | 0 | return nullptr; |
709 | 0 | } |
710 | | |
711 | 0 | return it->second.data(); |
712 | 0 | } |
713 | | |
714 | | void llama_context::attach_threadpool( |
715 | | ggml_threadpool_t threadpool, |
716 | 0 | ggml_threadpool_t threadpool_batch) { |
717 | 0 | LLAMA_LOG_DEBUG("%s: call\n", __func__); |
718 | |
|
719 | 0 | this->threadpool = threadpool; |
720 | 0 | this->threadpool_batch = threadpool_batch ? threadpool_batch : threadpool; |
721 | 0 | } |
722 | | |
723 | 0 | void llama_context::detach_threadpool() { |
724 | 0 | LLAMA_LOG_DEBUG("%s: call\n", __func__); |
725 | |
|
726 | 0 | this->threadpool = nullptr; |
727 | 0 | this->threadpool_batch = nullptr; |
728 | 0 | } |
729 | | |
730 | 0 | void llama_context::set_n_threads(int32_t n_threads, int32_t n_threads_batch) { |
731 | 0 | LLAMA_LOG_DEBUG("%s: n_threads = %d, n_threads_batch = %d\n", __func__, n_threads, n_threads_batch); |
732 | |
|
733 | 0 | cparams.n_threads = n_threads; |
734 | 0 | cparams.n_threads_batch = n_threads_batch; |
735 | 0 | } |
736 | | |
737 | 0 | void llama_context::set_abort_callback(bool (*abort_callback)(void * data), void * abort_callback_data) { |
738 | 0 | LLAMA_LOG_DEBUG("%s: call\n", __func__); |
739 | |
|
740 | 0 | this->abort_callback = abort_callback; |
741 | 0 | this->abort_callback_data = abort_callback_data; |
742 | |
|
743 | 0 | for (auto & backend : backends) { |
744 | 0 | auto * reg = ggml_backend_dev_backend_reg(ggml_backend_get_device(backend.get())); |
745 | 0 | auto * set_abort_callback_fn = (ggml_backend_set_abort_callback_t) ggml_backend_reg_get_proc_address(reg, "ggml_backend_set_abort_callback"); |
746 | 0 | if (set_abort_callback_fn) { |
747 | 0 | set_abort_callback_fn(backend.get(), this->abort_callback, this->abort_callback_data); |
748 | 0 | } |
749 | 0 | } |
750 | 0 | } |
751 | | |
752 | 0 | void llama_context::set_embeddings(bool value) { |
753 | 0 | LLAMA_LOG_DEBUG("%s: value = %d\n", __func__, value); |
754 | |
|
755 | 0 | cparams.embeddings = value; |
756 | 0 | } |
757 | | |
758 | 0 | void llama_context::set_causal_attn(bool value) { |
759 | 0 | LLAMA_LOG_DEBUG("%s: value = %d\n", __func__, value); |
760 | |
|
761 | 0 | cparams.causal_attn = value; |
762 | 0 | } |
763 | | |
764 | 0 | void llama_context::set_warmup(bool value) { |
765 | 0 | LLAMA_LOG_DEBUG("%s: value = %d\n", __func__, value); |
766 | |
|
767 | 0 | cparams.warmup = value; |
768 | 0 | } |
769 | | |
770 | | void llama_context::set_adapter_lora( |
771 | | llama_adapter_lora * adapter, |
772 | 0 | float scale) { |
773 | 0 | LLAMA_LOG_DEBUG("%s: adapter = %p, scale = %f\n", __func__, (void *) adapter, scale); |
774 | |
|
775 | 0 | loras[adapter] = scale; |
776 | 0 | } |
777 | | |
778 | | bool llama_context::rm_adapter_lora( |
779 | 0 | llama_adapter_lora * adapter) { |
780 | 0 | LLAMA_LOG_DEBUG("%s: adapter = %p\n", __func__, (void *) adapter); |
781 | |
|
782 | 0 | auto pos = loras.find(adapter); |
783 | 0 | if (pos != loras.end()) { |
784 | 0 | loras.erase(pos); |
785 | 0 | return true; |
786 | 0 | } |
787 | | |
788 | 0 | return false; |
789 | 0 | } |
790 | | |
791 | 0 | void llama_context::clear_adapter_lora() { |
792 | 0 | LLAMA_LOG_DEBUG("%s: call\n", __func__); |
793 | |
|
794 | 0 | loras.clear(); |
795 | 0 | } |
796 | | |
797 | | bool llama_context::apply_adapter_cvec( |
798 | | const float * data, |
799 | | size_t len, |
800 | | int32_t n_embd, |
801 | | int32_t il_start, |
802 | 0 | int32_t il_end) { |
803 | 0 | LLAMA_LOG_DEBUG("%s: il_start = %d, il_end = %d\n", __func__, il_start, il_end); |
804 | |
|
805 | 0 | return cvec.apply(model, data, len, n_embd, il_start, il_end); |
806 | 0 | } |
807 | | |
808 | 0 | llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, llm_graph_type gtype, llama_memory_context_i * mctx, ggml_status & ret) { |
809 | 0 | if (mctx && !mctx->apply()) { |
810 | 0 | LLAMA_LOG_ERROR("%s: failed to apply memory context\n", __func__); |
811 | 0 | ret = GGML_STATUS_FAILED; |
812 | 0 | return nullptr; |
813 | 0 | } |
814 | | |
815 | 0 | auto * res = gf_res_prev.get(); |
816 | 0 | auto * gf = res->get_gf(); |
817 | | |
818 | | // the new graph parameters |
819 | | // in order to correctly reuse a graph, its full topology has to be uniquely determined by these parameters |
820 | 0 | const auto gparams = graph_params(res, ubatch, mctx, gtype); |
821 | |
|
822 | 0 | if (!graph_reuse_disable && res->can_reuse(gparams)) { |
823 | | //LLAMA_LOG_DEBUG("%s: reusing previous graph\n", __func__); |
824 | |
|
825 | 0 | n_reused++; |
826 | 0 | } else { |
827 | 0 | res->reset(); |
828 | |
|
829 | 0 | ggml_backend_sched_reset(sched.get()); |
830 | 0 | ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data); |
831 | | |
832 | | //const auto t_start_us = ggml_time_us(); |
833 | |
|
834 | 0 | gf = model.build_graph(gparams); |
835 | | |
836 | | //LLAMA_LOG_INFO("graph build time: %.3f ms\n", (ggml_time_us() - t_start_us)/1000.0); |
837 | |
|
838 | 0 | if (!gf) { |
839 | 0 | LLAMA_LOG_ERROR("%s: failed to initialize graph\n", __func__); |
840 | 0 | ret = GGML_STATUS_FAILED; |
841 | 0 | return nullptr; |
842 | 0 | } |
843 | | |
844 | 0 | if (!ggml_backend_sched_alloc_graph(sched.get(), gf)) { |
845 | 0 | LLAMA_LOG_ERROR("%s: failed to allocate graph\n", __func__); |
846 | 0 | ret = GGML_STATUS_ALLOC_FAILED; |
847 | 0 | return nullptr; |
848 | 0 | } |
849 | 0 | } |
850 | | |
851 | | // set the input data for the input tensors |
852 | 0 | { |
853 | | //const auto t_start_us = ggml_time_us(); |
854 | |
|
855 | 0 | res->set_inputs(&ubatch); |
856 | | |
857 | | //LLAMA_LOG_INFO("graph set inputs time: %.3f ms\n", (ggml_time_us() - t_start_us)/1000.0); |
858 | 0 | } |
859 | |
|
860 | 0 | const auto status = graph_compute(res->get_gf(), ubatch.n_tokens > 1); |
861 | 0 | if (status != GGML_STATUS_SUCCESS) { |
862 | 0 | LLAMA_LOG_ERROR("%s: failed to compute graph, compute status: %d\n", __func__, status); |
863 | 0 | ret = status; |
864 | 0 | return nullptr; |
865 | 0 | } |
866 | | |
867 | 0 | ret = GGML_STATUS_SUCCESS; |
868 | |
|
869 | 0 | return res; |
870 | 0 | } |
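The reuse-or-rebuild decision in `process_ubatch()` amounts to memoizing the previous graph keyed on its build parameters. A schematic sketch of that pattern with hypothetical stand-in types (not the real `llm_graph_params`/`llm_graph_result`):

```cpp
#include <optional>

// Hypothetical: Params must capture everything that determines the graph topology
// (ubatch shape, graph type, memory-context state, ...).
struct Params {
    int  n_tokens;
    int  n_outputs;
    bool causal;
};

static bool operator==(const Params & a, const Params & b) {
    return a.n_tokens == b.n_tokens && a.n_outputs == b.n_outputs && a.causal == b.causal;
}

struct Graph { /* nodes ... */ };

struct GraphCache {
    std::optional<Params> prev_params;
    Graph                 prev_graph;
    bool                  reuse_disabled = false; // cf. the LLAMA_GRAPH_REUSE_DISABLE env var

    // Rebuild only when the parameters changed; otherwise hand back the cached graph.
    Graph & get(const Params & p, Graph (*build)(const Params &)) {
        if (!reuse_disabled && prev_params && *prev_params == p) {
            return prev_graph;          // topology unchanged -> reuse (n_reused++ above)
        }
        prev_graph  = build(p);         // the real code also resets the scheduler and re-allocates
        prev_params = p;
        return prev_graph;
    }
};
```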
871 | | |
872 | 0 | int llama_context::encode(const llama_batch & batch_inp) { |
873 | 0 | GGML_ASSERT((!batch_inp.token && batch_inp.embd) || (batch_inp.token && !batch_inp.embd)); // NOLINT |
874 | |
|
875 | 0 | if (batch_inp.n_tokens == 0) { |
876 | 0 | LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__); |
877 | 0 | return -1; |
878 | 0 | } |
879 | | |
880 | 0 | const auto & hparams = model.hparams; |
881 | |
|
882 | 0 | const int64_t n_embd = hparams.n_embd_inp(); |
883 | 0 | const int64_t n_vocab = model.vocab.n_tokens(); |
884 | | |
885 | | // note: during encode, we always pass the full sequence starting from pos = 0 |
886 | 0 | if (!balloc->init(batch_inp, model.vocab, nullptr, n_embd, cparams.kv_unified ? LLAMA_MAX_SEQ : cparams.n_seq_max, true)) { |
887 | 0 | LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__); |
888 | 0 | return -1; |
889 | 0 | } |
890 | | |
891 | 0 | const uint32_t n_tokens = balloc->get_n_tokens(); |
892 | | |
893 | | // [TAG_NO_CACHE_PAD] |
894 | | // TODO: add new split mode where we pad the input sequences so that ubatch.equal_seqs == true |
895 | 0 | const llama_ubatch ubatch = balloc->split_simple(n_tokens); |
896 | | |
897 | | // micro-batching is not possible for non-causal encoding, so we process the batch in a single shot |
898 | 0 | GGML_ASSERT(cparams.n_ubatch >= n_tokens && "encoder requires n_ubatch >= n_tokens"); |
899 | |
|
900 | 0 | if (t_compute_start_us == 0) { |
901 | 0 | t_compute_start_us = ggml_time_us(); |
902 | 0 | } |
903 | | |
904 | | // TODO: clearing this buffer can easily be forgotten - need something better |
905 | 0 | embd_seq.clear(); |
906 | |
|
907 | 0 | n_queued_tokens += n_tokens; |
908 | | |
909 | | // reserve output buffer |
910 | 0 | if (output_reserve(n_tokens) < n_tokens) { |
911 | 0 | LLAMA_LOG_ERROR("%s: could not reserve space for batch with %u outputs\n", __func__, n_tokens); |
912 | 0 | return -2; |
913 | 0 | }; |
914 | |
|
915 | 0 | for (uint32_t i = 0; i < n_tokens; ++i) { |
916 | 0 | output_ids[i] = i; |
917 | 0 | } |
918 | |
|
919 | 0 | n_outputs = n_tokens; |
920 | |
|
921 | 0 | const auto causal_attn_org = cparams.causal_attn; |
922 | | |
923 | | // always use non-causal attention for encoder graphs |
924 | | // TODO: this is a tmp solution until we have a proper way to support enc-dec models |
925 | | // ref: https://github.com/ggml-org/llama.cpp/pull/12181#issuecomment-2730451223 |
926 | 0 | cparams.causal_attn = false; |
927 | |
|
928 | 0 | ggml_status status; |
929 | 0 | const auto * res = process_ubatch(ubatch, LLM_GRAPH_TYPE_ENCODER, nullptr, status); |
930 | |
|
931 | 0 | cparams.causal_attn = causal_attn_org; |
932 | |
|
933 | 0 | if (!res) { |
934 | 0 | switch (status) { |
935 | 0 | case GGML_STATUS_ABORTED: return 2; |
936 | 0 | case GGML_STATUS_ALLOC_FAILED: return -2; |
937 | 0 | case GGML_STATUS_FAILED: return -3; |
938 | 0 | case GGML_STATUS_SUCCESS: GGML_ABORT("should not happen"); |
939 | 0 | } |
940 | 0 | } |
941 | | |
942 | 0 | auto * t_logits = res->get_logits(); |
943 | 0 | auto * t_embd = res->get_embd_pooled() ? res->get_embd_pooled() : res->get_embd(); |
944 | | |
945 | | // extract logits |
946 | 0 | if (logits && t_logits) { |
947 | 0 | ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(sched.get(), t_logits); |
948 | 0 | GGML_ASSERT(backend_res != nullptr); |
949 | 0 | GGML_ASSERT(logits != nullptr); |
950 | |
|
951 | 0 | ggml_backend_tensor_get_async(backend_res, t_logits, logits, 0, n_tokens*n_vocab*sizeof(float)); |
952 | 0 | } |
953 | | |
954 | | // extract embeddings |
955 | 0 | if (embd && t_embd) { |
956 | 0 | ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(sched.get(), t_embd); |
957 | 0 | GGML_ASSERT(backend_embd != nullptr); |
958 | |
|
959 | 0 | switch (cparams.pooling_type) { |
960 | 0 | case LLAMA_POOLING_TYPE_NONE: |
961 | 0 | { |
962 | | // extract token embeddings |
963 | 0 | GGML_ASSERT(embd != nullptr); |
964 | |
|
965 | 0 | GGML_ASSERT(n_tokens*n_embd <= (int64_t) embd_size); |
966 | 0 | ggml_backend_tensor_get_async(backend_embd, t_embd, embd, 0, n_tokens*n_embd*sizeof(float)); |
967 | 0 | } break; |
968 | 0 | case LLAMA_POOLING_TYPE_MEAN: |
969 | 0 | case LLAMA_POOLING_TYPE_CLS: |
970 | 0 | case LLAMA_POOLING_TYPE_LAST: |
971 | 0 | { |
972 | | // extract sequence embeddings |
973 | 0 | auto & embd_seq_out = embd_seq; |
974 | |
|
975 | 0 | for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) { |
976 | 0 | const llama_seq_id seq_id = ubatch.seq_id_unq[s]; |
977 | 0 | const int32_t seq_idx = ubatch.seq_idx[seq_id]; |
978 | |
|
979 | 0 | embd_seq_out[seq_id].resize(n_embd); |
980 | 0 | ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_embd*seq_idx)*sizeof(float), n_embd*sizeof(float)); |
981 | 0 | } |
982 | 0 | } break; |
983 | 0 | case LLAMA_POOLING_TYPE_RANK: |
984 | 0 | { |
985 | | // extract the rerank score - n_cls_out floats per sequence |
986 | 0 | auto & embd_seq_out = embd_seq; |
987 | |
|
988 | 0 | const uint32_t n_cls_out = hparams.n_cls_out; |
989 | |
|
990 | 0 | for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) { |
991 | 0 | const llama_seq_id seq_id = ubatch.seq_id_unq[s]; |
992 | 0 | const int32_t seq_idx = ubatch.seq_idx[seq_id]; |
993 | |
|
994 | 0 | embd_seq_out[seq_id].resize(n_cls_out); |
995 | 0 | ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_cls_out*seq_idx)*sizeof(float), n_cls_out*sizeof(float)); |
996 | 0 | } |
997 | 0 | } break; |
998 | 0 | case LLAMA_POOLING_TYPE_UNSPECIFIED: |
999 | 0 | { |
1000 | 0 | GGML_ABORT("unknown pooling type"); |
1001 | 0 | } |
1002 | 0 | } |
1003 | 0 | } |
1004 | | |
1005 | | // TODO: hacky solution |
1006 | 0 | if (model.arch == LLM_ARCH_T5 && t_embd) { |
1007 | | //cross.t_embd = t_embd; |
1008 | |
|
1009 | 0 | synchronize(); |
1010 | |
|
1011 | 0 | cross.n_embd = t_embd->ne[0]; |
1012 | 0 | cross.n_enc = t_embd->ne[1]; |
1013 | 0 | cross.v_embd.resize(cross.n_embd*cross.n_enc); |
1014 | 0 | memcpy(cross.v_embd.data(), embd, ggml_nbytes(t_embd)); |
1015 | |
|
1016 | 0 | const auto & batch = balloc->get_batch(); |
1017 | | |
1018 | | // remember the sequence ids used during the encoding - needed for cross attention later |
1019 | 0 | cross.seq_ids_enc.resize(n_tokens); |
1020 | 0 | for (uint32_t i = 0; i < n_tokens; i++) { |
1021 | 0 | cross.seq_ids_enc[i].clear(); |
1022 | |
|
1023 | 0 | for (int s = 0; s < batch.n_seq_id[i]; s++) { |
1024 | 0 | const llama_seq_id seq_id = batch.seq_id[i][s]; |
1025 | |
|
1026 | 0 | cross.seq_ids_enc[i].insert(seq_id); |
1027 | 0 | } |
1028 | 0 | } |
1029 | 0 | } |
1030 | |
|
1031 | 0 | return 0; |
1032 | 0 | } |
1033 | | |
1034 | 0 | int llama_context::decode(const llama_batch & batch_inp) { |
1035 | 0 | GGML_ASSERT((!batch_inp.token && batch_inp.embd) || (batch_inp.token && !batch_inp.embd)); // NOLINT |
1036 | |
|
1037 | 0 | if (!memory) { |
1038 | 0 | LLAMA_LOG_DEBUG("%s: cannot decode batches with this context (calling encode() instead)\n", __func__); |
1039 | 0 | return encode(batch_inp); |
1040 | 0 | } |
1041 | | |
1042 | 0 | if (batch_inp.n_tokens == 0) { |
1043 | 0 | LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__); |
1044 | 0 | return -1; |
1045 | 0 | } |
1046 | | |
1047 | 0 | const auto & vocab = model.vocab; |
1048 | 0 | const auto & hparams = model.hparams; |
1049 | |
|
1050 | 0 | const int64_t n_vocab = vocab.n_tokens(); |
1051 | 0 | const int64_t n_embd = hparams.n_embd_inp(); |
1052 | | |
1053 | | // when computing embeddings, all tokens are output |
1054 | 0 | const bool output_all = cparams.embeddings; |
1055 | |
|
1056 | 0 | if (!balloc->init(batch_inp, vocab, memory.get(), n_embd, cparams.kv_unified ? LLAMA_MAX_SEQ : cparams.n_seq_max, output_all)) { |
1057 | 0 | LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__); |
1058 | 0 | return -1; |
1059 | 0 | } |
1060 | | |
1061 | 0 | const uint32_t n_tokens_all = balloc->get_n_tokens(); |
1062 | 0 | const uint32_t n_outputs_all = balloc->get_n_outputs(); |
1063 | |
|
1064 | 0 | if (output_all) { |
1065 | | // require that all tokens are output |
1066 | 0 | if (n_outputs_all != n_tokens_all) { |
1067 | 0 | LLAMA_LOG_ERROR("%s: pooled embedding requires that all tokens are output (n_outputs_all = %d, n_tokens_all = %d)\n", |
1068 | 0 | __func__, n_outputs_all, n_tokens_all); |
1069 | 0 | return -1; |
1070 | 0 | } |
1071 | 0 | } |
1072 | | |
1073 | 0 | GGML_ASSERT(n_tokens_all <= cparams.n_batch); |
1074 | |
|
1075 | 0 | GGML_ASSERT((cparams.causal_attn || cparams.n_ubatch >= n_tokens_all) && "non-causal attention requires n_ubatch >= n_tokens"); |
1076 | |
|
1077 | 0 | if (t_compute_start_us == 0) { |
1078 | 0 | t_compute_start_us = ggml_time_us(); |
1079 | 0 | } |
1080 | 0 | n_queued_tokens += n_tokens_all; |
1081 | | |
1082 | | // TODO: clearing this buffer can easily be forgotten - need something better |
1083 | 0 | embd_seq.clear(); |
1084 | 0 | output_swaps.clear(); |
1085 | |
|
1086 | 0 | bool did_optimize = false; |
1087 | | |
1088 | | // handle any pending shifts/copies |
1089 | 0 | memory_update(false); |
1090 | |
|
1091 | 0 | llama_memory_context_ptr mctx; |
1092 | |
|
1093 | 0 | while (true) { |
1094 | 0 | mctx = memory->init_batch(*balloc, cparams.n_ubatch, output_all); |
1095 | 0 | if (!mctx) { |
1096 | 0 | return -2; |
1097 | 0 | } |
1098 | | |
1099 | 0 | switch (mctx->get_status()) { |
1100 | 0 | case LLAMA_MEMORY_STATUS_SUCCESS: |
1101 | 0 | { |
1102 | 0 | } break; |
1103 | 0 | case LLAMA_MEMORY_STATUS_NO_UPDATE: |
1104 | 0 | { |
1105 | 0 | LLAMA_LOG_ERROR("%s: unexpected memory context status: %d\n", __func__, mctx->get_status()); |
1106 | |
|
1107 | 0 | return -2; |
1108 | 0 | } |
1109 | 0 | case LLAMA_MEMORY_STATUS_FAILED_PREPARE: |
1110 | 0 | { |
1111 | 0 | if (!did_optimize) { |
1112 | 0 | did_optimize = true; |
1113 | |
|
1114 | 0 | if (memory_update(true)) { |
1115 | 0 | LLAMA_LOG_DEBUG("%s: retrying batch size %d after cache optimization\n", __func__, balloc->get_n_tokens()); |
1116 | |
|
1117 | 0 | continue; |
1118 | 0 | } |
1119 | 0 | } |
1120 | | |
1121 | 0 | LLAMA_LOG_WARN("%s: failed to find a memory slot for batch of size %d\n", __func__, balloc->get_n_tokens()); |
1122 | |
|
1123 | 0 | return 1; |
1124 | 0 | } |
1125 | 0 | case LLAMA_MEMORY_STATUS_FAILED_COMPUTE: |
1126 | 0 | { |
1127 | 0 | LLAMA_LOG_ERROR("%s: compute failed while preparing batch of size %d\n", __func__, balloc->get_n_tokens()); |
1128 | |
|
1129 | 0 | return -2; |
1130 | 0 | } |
1131 | 0 | } |
1132 | | |
1133 | 0 | break; |
1134 | 0 | } |
1135 | | |
1136 | | // reserve output buffer |
1137 | 0 | if (output_reserve(n_outputs_all) < n_outputs_all) { |
1138 | 0 | LLAMA_LOG_ERROR("%s: could not reserve space for batch with %d outputs\n", __func__, n_outputs_all); |
1139 | 0 | return -2; |
1140 | 0 | }; |
1141 | |
|
1142 | 0 | int64_t n_outputs_prev = 0; |
1143 | |
|
1144 | 0 | do { |
1145 | 0 | const auto & ubatch = mctx->get_ubatch(); |
1146 | | |
1147 | | // count the outputs in this ubatch |
1148 | 0 | { |
1149 | 0 | int32_t n_outputs_new = 0; |
1150 | |
|
1151 | 0 | if (n_outputs_all == n_tokens_all) { |
1152 | 0 | n_outputs_new = ubatch.n_tokens; |
1153 | 0 | } else { |
1154 | 0 | for (uint32_t i = 0; i < ubatch.n_tokens; i++) { |
1155 | 0 | n_outputs_new += (int32_t) (ubatch.output[i] != 0); |
1156 | 0 | } |
1157 | 0 | } |
1158 | | |
1159 | | // needs to happen before the graph is built |
1160 | 0 | n_outputs = n_outputs_new; |
1161 | 0 | } |
1162 | |
|
1163 | 0 | ggml_status status; |
1164 | 0 | const auto * res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, mctx.get(), status); |
1165 | |
|
1166 | 0 | if (!res) { |
1167 | | // the last ubatch failed or was aborted -> remove all positions of that ubatch from the memory module |
1168 | 0 | llama_pos pos_min[LLAMA_MAX_SEQ]; |
1169 | 0 | for (int s = 0; s < LLAMA_MAX_SEQ; ++s) { |
1170 | 0 | pos_min[s] = std::numeric_limits<llama_pos>::max(); |
1171 | 0 | } |
1172 | |
|
1173 | 0 | for (uint32_t i = 0; i < ubatch.n_tokens; ++i) { |
1174 | 0 | const auto & seq_id = ubatch.seq_id[i][0]; |
1175 | |
|
1176 | 0 | pos_min[seq_id] = std::min(pos_min[seq_id], ubatch.pos[i]); |
1177 | 0 | } |
1178 | |
|
1179 | 0 | for (int s = 0; s < LLAMA_MAX_SEQ; ++s) { |
1180 | 0 | if (pos_min[s] == std::numeric_limits<llama_pos>::max()) { |
1181 | 0 | continue; |
1182 | 0 | } |
1183 | | |
1184 | 0 | LLAMA_LOG_WARN("%s: removing memory module entries for seq_id = %d, pos = [%d, +inf)\n", __func__, s, pos_min[s]); |
1185 | |
|
1186 | 0 | memory->seq_rm(s, pos_min[s], -1); |
1187 | 0 | } |
1188 | |
|
1189 | 0 | switch (status) { |
1190 | 0 | case GGML_STATUS_ABORTED: return 2; |
1191 | 0 | case GGML_STATUS_ALLOC_FAILED: return -2; |
1192 | 0 | case GGML_STATUS_FAILED: return -3; |
1193 | 0 | case GGML_STATUS_SUCCESS: GGML_ABORT("should not happen"); |
1194 | 0 | } |
1195 | 0 | } |
1196 | | |
1197 | | // plot the computation graph in dot format (for debugging purposes) |
1198 | | //if (n_past%100 == 0) { |
1199 | | // ggml_graph_dump_dot(gf, NULL, "llama.dot"); |
1200 | | //} |
1201 | | |
1202 | 0 | auto * t_logits = res->get_logits(); |
1203 | 0 | auto * t_embd = cparams.embeddings ? res->get_embd() : nullptr; |
1204 | |
|
1205 | 0 | if (t_embd && res->get_embd_pooled()) { |
1206 | 0 | t_embd = res->get_embd_pooled(); |
1207 | 0 | } |
1208 | | |
1209 | | // extract logits |
1210 | 0 | if (t_logits && n_outputs > 0) { |
1211 | 0 | ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(sched.get(), t_logits); |
1212 | 0 | GGML_ASSERT(backend_res != nullptr); |
1213 | 0 | GGML_ASSERT(logits != nullptr); |
1214 | |
|
1215 | 0 | float * logits_out = logits + n_outputs_prev*n_vocab; |
1216 | |
|
1217 | 0 | if (n_outputs) { |
1218 | 0 | GGML_ASSERT( n_outputs_prev + n_outputs <= n_outputs_all); |
1219 | 0 | GGML_ASSERT((n_outputs_prev + n_outputs)*n_vocab <= (int64_t) logits_size); |
1220 | 0 | ggml_backend_tensor_get_async(backend_res, t_logits, logits_out, 0, n_outputs*n_vocab*sizeof(float)); |
1221 | 0 | } |
1222 | 0 | } |
1223 | | |
1224 | | // extract embeddings |
1225 | 0 | if (t_embd && n_outputs > 0) { |
1226 | 0 | ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(sched.get(), t_embd); |
1227 | 0 | GGML_ASSERT(backend_embd != nullptr); |
1228 | |
|
1229 | 0 | switch (cparams.pooling_type) { |
1230 | 0 | case LLAMA_POOLING_TYPE_NONE: |
1231 | 0 | { |
1232 | | // extract token embeddings |
1233 | 0 | GGML_ASSERT(embd != nullptr); |
1234 | 0 | float * embd_out = embd + n_outputs_prev*n_embd; |
1235 | |
|
1236 | 0 | if (n_outputs) { |
1237 | 0 | GGML_ASSERT( n_outputs_prev + n_outputs <= n_outputs_all); |
1238 | 0 | GGML_ASSERT((n_outputs_prev + n_outputs)*n_embd <= (int64_t) embd_size); |
1239 | 0 | ggml_backend_tensor_get_async(backend_embd, t_embd, embd_out, 0, n_outputs*n_embd*sizeof(float)); |
1240 | 0 | } |
1241 | 0 | } break; |
1242 | 0 | case LLAMA_POOLING_TYPE_MEAN: |
1243 | 0 | case LLAMA_POOLING_TYPE_CLS: |
1244 | 0 | case LLAMA_POOLING_TYPE_LAST: |
1245 | 0 | { |
1246 | | // extract sequence embeddings (cleared before processing each batch) |
1247 | 0 | auto & embd_seq_out = embd_seq; |
1248 | |
|
1249 | 0 | for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) { |
1250 | 0 | const llama_seq_id seq_id = ubatch.seq_id_unq[s]; |
1251 | 0 | const int32_t seq_idx = ubatch.seq_idx[seq_id]; |
1252 | |
|
1253 | 0 | embd_seq_out[seq_id].resize(n_embd); |
1254 | 0 | ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_embd*seq_idx)*sizeof(float), n_embd*sizeof(float)); |
1255 | 0 | } |
1256 | 0 | } break; |
1257 | 0 | case LLAMA_POOLING_TYPE_RANK: |
1258 | 0 | { |
1259 | | // extract the rerank score - n_cls_out floats per sequence |
1260 | 0 | auto & embd_seq_out = embd_seq; |
1261 | |
|
1262 | 0 | const uint32_t n_cls_out = hparams.n_cls_out; |
1263 | |
|
1264 | 0 | for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) { |
1265 | 0 | const llama_seq_id seq_id = ubatch.seq_id_unq[s]; |
1266 | 0 | const int32_t seq_idx = ubatch.seq_idx[seq_id]; |
1267 | |
|
1268 | 0 | embd_seq_out[seq_id].resize(n_cls_out); |
1269 | 0 | ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_cls_out*seq_idx)*sizeof(float), n_cls_out*sizeof(float)); |
1270 | 0 | } |
1271 | 0 | } break; |
1272 | 0 | case LLAMA_POOLING_TYPE_UNSPECIFIED: |
1273 | 0 | { |
1274 | 0 | GGML_ABORT("unknown pooling type"); |
1275 | 0 | } |
1276 | 0 | } |
1277 | 0 | } |
1278 | | |
1279 | 0 | n_outputs_prev += n_outputs; |
1280 | 0 | } while (mctx->next()); |
1281 | | |
1282 | | // set to total number of outputs in the batch, for use in llama_get_logits_ith |
1283 | 0 | n_outputs = n_outputs_all; |
1284 | | |
1285 | | // set output mappings |
1286 | 0 | if (n_outputs > 0) { |
1287 | 0 | bool sorted_output = true; |
1288 | |
|
1289 | 0 | auto & out_ids = balloc->get_out_ids(); |
1290 | |
|
1291 | 0 | GGML_ASSERT(out_ids.size() == (size_t) n_outputs); |
1292 | |
|
1293 | 0 | for (int64_t i = 0; i < n_outputs; ++i) { |
1294 | 0 | int64_t out_id = out_ids[i]; |
1295 | 0 | output_ids[out_id] = i; |
1296 | 0 | if (out_id != i) { |
1297 | 0 | sorted_output = false; |
1298 | 0 | } |
1299 | 0 | } |
1300 | | |
1301 | | // make the outputs have the same order they had in the user-provided batch |
1302 | | // note: this is mostly relevant for recurrent models atm |
1303 | 0 | if (!sorted_output && n_outputs > 1) { |
1304 | 0 | GGML_ASSERT((size_t) n_outputs == out_ids.size()); |
1305 | | |
1306 | | // TODO: is there something more efficient which also minimizes swaps? |
1307 | | // selection sort, to minimize swaps (from https://en.wikipedia.org/wiki/Selection_sort) |
1308 | 0 | for (uint32_t i = 0; i < n_outputs - 1; ++i) { |
1309 | 0 | uint32_t j_min = i; |
1310 | 0 | for (uint32_t j = i + 1; j < n_outputs; ++j) { |
1311 | 0 | if (out_ids[j] < out_ids[j_min]) { |
1312 | 0 | j_min = j; |
1313 | 0 | } |
1314 | 0 | } |
1315 | 0 | if (j_min == i) { |
1316 | 0 | continue; |
1317 | 0 | } |
1318 | 0 | std::swap(out_ids[i], out_ids[j_min]); |
1319 | | |
1320 | | // remember the swaps and apply them lazily upon logits/embeddings access |
1321 | 0 | output_swaps.push_back({ i, j_min }); |
1322 | 0 | } |
1323 | |
|
1324 | 0 | std::fill(output_ids.begin(), output_ids.end(), -1); |
1325 | |
|
1326 | 0 | for (uint32_t i = 0; i < n_outputs; ++i) { |
1327 | 0 | output_ids[out_ids[i]] = i; |
1328 | 0 | } |
1329 | 0 | } |
1330 | 0 | } |
1331 | | |
1332 | | // wait for the computation to finish (automatically done when obtaining the model output) |
1333 | | //synchronize(); |
1334 | |
|
1335 | 0 | return 0; |
1336 | 0 | } |
1337 | | |
1338 | | // |
1339 | | // output |
1340 | | // |
1341 | | |
1342 | 0 | uint32_t llama_context::output_reserve(int32_t n_outputs) { |
1343 | 0 | const auto & hparams = model.hparams; |
1344 | 0 | const auto & vocab = model.vocab; |
1345 | |
|
1346 | 0 | const int64_t n_outputs_max = std::max<int64_t>(n_outputs, n_seq_max()); |
1347 | |
|
1348 | 0 | const auto n_batch = cparams.n_batch; |
1349 | 0 | const auto n_vocab = vocab.n_tokens(); |
1350 | 0 | const auto n_embd = hparams.n_embd; |
1351 | |
|
1352 | 0 | bool has_logits = true; |
1353 | 0 | bool has_embd = cparams.embeddings; |
1354 | | |
1355 | | // TODO: hacky enc-dec support |
1356 | 0 | if (model.arch == LLM_ARCH_T5) { |
1357 | 0 | has_logits = true; |
1358 | 0 | has_embd = true; |
1359 | 0 | } |
1360 | |
|
1361 | 0 | logits_size = has_logits ? n_vocab*n_outputs_max : 0; |
1362 | 0 | embd_size = has_embd ? n_embd*n_outputs_max : 0; |
1363 | |
|
1364 | 0 | if (output_ids.empty()) { |
1365 | | // init, never resized afterwards |
1366 | 0 | output_ids.resize(n_batch); |
1367 | 0 | } |
1368 | |
|
1369 | 0 | const size_t prev_size = buf_output ? ggml_backend_buffer_get_size(buf_output.get()) : 0; |
1370 | 0 | const size_t new_size = (logits_size + embd_size) * sizeof(float); |
1371 | | |
1372 | | // alloc only when more than the current capacity is required |
1373 | | // TODO: also consider shrinking the buffer |
1374 | 0 | if (!buf_output || prev_size < new_size) { |
1375 | 0 | if (buf_output) { |
1376 | | #ifndef NDEBUG |
1377 | | // This doesn't happen often, but may be annoying in some cases (like the HellaSwag benchmark) |
1378 | | LLAMA_LOG_INFO("%s: reallocating output buffer from size %.02f MiB to %.02f MiB\n", __func__, prev_size / 1024.0 / 1024.0, new_size / 1024.0 / 1024.0); |
1379 | | #endif |
1380 | 0 | synchronize(); |
1381 | 0 | buf_output = nullptr; |
1382 | 0 | logits = nullptr; |
1383 | 0 | embd = nullptr; |
1384 | 0 | } |
1385 | |
|
1386 | 0 | auto * buft = ggml_backend_cpu_buffer_type(); |
1387 | | // try to use the host buffer of the device where the output tensor is allocated for faster transfer to system memory |
1388 | 0 | auto * output_dev = model.dev_output(); |
1389 | 0 | auto * output_dev_host_buft = output_dev ? ggml_backend_dev_host_buffer_type(output_dev) : nullptr; |
1390 | 0 | if (output_dev_host_buft) { |
1391 | 0 | buft = output_dev_host_buft; |
1392 | 0 | } |
1393 | 0 | buf_output.reset(ggml_backend_buft_alloc_buffer(buft, new_size)); |
1394 | 0 | if (buf_output == nullptr) { |
1395 | 0 | LLAMA_LOG_ERROR("%s: failed to allocate output buffer of size %.2f MiB\n", __func__, new_size / (1024.0 * 1024.0)); |
1396 | 0 | return 0; |
1397 | 0 | } |
1398 | 0 | } |
1399 | | |
1400 | 0 | float * output_base = (float *) ggml_backend_buffer_get_base(buf_output.get()); |
1401 | |
|
1402 | 0 | logits = has_logits ? output_base : nullptr; |
1403 | 0 | embd = has_embd ? output_base + logits_size : nullptr; |
1404 | | |
1405 | | // set all ids as invalid (negative) |
1406 | 0 | std::fill(output_ids.begin(), output_ids.end(), -1); |
1407 | |
|
1408 | 0 | this->n_outputs = 0; |
1409 | |
|
1410 | 0 | return n_outputs_max; |
1411 | 0 | } |
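For scale, a worked example of the sizing logic in output_reserve (illustrative numbers, not taken from this file): with n_vocab = 32000, n_embd = 4096, embeddings enabled, and n_outputs_max = 512, the allocation is (logits_size + embd_size) * sizeof(float) = (32000*512 + 4096*512) * 4 bytes = 73,924,608 bytes, roughly 70.5 MiB; the reallocation path above is taken only when this exceeds the size of the existing output buffer.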
1412 | | |
1413 | 0 | void llama_context::output_reorder() { |
1414 | 0 | const uint64_t n_vocab = model.vocab.n_tokens(); |
1415 | 0 | const uint64_t n_embd = model.hparams.n_embd; |
1416 | |
|
1417 | 0 | for (size_t s = 0; s < output_swaps.size(); ++s) { |
1418 | 0 | const uint64_t i0 = output_swaps[s].i0; |
1419 | 0 | const uint64_t i1 = output_swaps[s].i1; |
1420 | |
|
1421 | 0 | if (logits_size > 0) { |
1422 | 0 | for (uint64_t k = 0; k < n_vocab; k++) { |
1423 | 0 | std::swap(logits[i0*n_vocab + k], logits[i1*n_vocab + k]); |
1424 | 0 | } |
1425 | 0 | } |
1426 | |
|
1427 | 0 | if (embd_size > 0) { |
1428 | 0 | for (uint64_t k = 0; k < n_embd; k++) { |
1429 | 0 | std::swap(embd[i0*n_embd + k], embd[i1*n_embd + k]); |
1430 | 0 | } |
1431 | 0 | } |
1432 | 0 | } |
1433 | |
|
1434 | 0 | output_swaps.clear(); |
1435 | 0 | } |
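The two functions above implement a record-and-replay scheme: a selection sort runs once over out_ids while each swap is pushed onto output_swaps, and output_reorder() later replays exactly those swaps on the logits and embeddings rows. A minimal, self-contained sketch of that pattern (illustrative only; keys and payload are made-up stand-ins for out_ids and the per-output rows):

#include <cstdio>
#include <utility>
#include <vector>

int main() {
    std::vector<int>   keys    = { 3, 0, 2, 1 };         // stand-in for out_ids
    std::vector<float> payload = { 3.f, 0.f, 2.f, 1.f }; // stand-in for per-output rows

    std::vector<std::pair<size_t, size_t>> swaps;        // analogous to output_swaps

    // selection sort on the keys, remembering every swap that was performed
    for (size_t i = 0; i + 1 < keys.size(); ++i) {
        size_t j_min = i;
        for (size_t j = i + 1; j < keys.size(); ++j) {
            if (keys[j] < keys[j_min]) {
                j_min = j;
            }
        }
        if (j_min != i) {
            std::swap(keys[i], keys[j_min]);
            swaps.push_back({ i, j_min });
        }
    }

    // lazily replay the same swaps on the payload (what output_reorder() does for logits/embd)
    for (const auto & s : swaps) {
        std::swap(payload[s.first], payload[s.second]);
    }

    for (float v : payload) {
        printf("%g ", v); // prints: 0 1 2 3
    }
    printf("\n");
    return 0;
}

Replaying the recorded swaps leaves the payload untouched until it is actually needed, which is why the reorder can be deferred to the logits/embeddings getters.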
1436 | | |
1437 | | // |
1438 | | // graph |
1439 | | // |
1440 | | |
1441 | 0 | uint32_t llama_context::graph_max_nodes(uint32_t n_tokens) const { |
1442 | 0 | if (model.arch == LLM_ARCH_QWEN3NEXT) { |
1443 | 0 | return std::max<uint32_t>(n_tokens * 40, 32u * model.n_tensors()); |
1444 | 0 | } |
1445 | 0 | return std::max<uint32_t>(1024u, 8u*model.n_tensors()); |
1446 | 0 | } |
1447 | | |
1448 | 0 | llm_graph_result * llama_context::get_gf_res_reserve() const { |
1449 | 0 | return static_cast<llm_graph_result *>(gf_res_reserve.get()); |
1450 | 0 | } |
1451 | | |
1452 | | ggml_cgraph * llama_context::graph_reserve( |
1453 | 0 | uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx, bool split_only, size_t * sizes) { |
1454 | 0 | LLAMA_LOG_DEBUG("%s: reserving a graph for ubatch with n_tokens = %4u, n_seqs = %2u, n_outputs = %4u\n", __func__, n_tokens, n_seqs, n_outputs); |
1455 | 0 | GGML_ASSERT(n_outputs >= 1); |
1456 | |
|
1457 | 0 | if (n_tokens % n_seqs != 0) { |
1458 | 0 | n_tokens = ((n_tokens + (n_seqs - 1)) / n_seqs) * n_seqs; // round to next multiple of n_seqs |
1459 | 0 | n_outputs = std::min(n_outputs, n_tokens); |
1460 | |
|
1461 | 0 | LLAMA_LOG_DEBUG("%s: making n_tokens a multiple of n_seqs - n_tokens = %u, n_seqs = %u, n_outputs = %u\n", __func__, n_tokens, n_seqs, n_outputs); |
1462 | 0 | } |
1463 | |
|
1464 | 0 | ggml_backend_sched_reset(sched.get()); |
1465 | | |
1466 | | // when the scheduler is reset, we cannot reuse the old graph, so we reset the previous graph result to prevent that |
1467 | 0 | gf_res_prev->reset(); |
1468 | | |
1469 | | // store the n_outputs as it is, and restore it afterwards |
1470 | | // TODO: not sure if needed, might simplify in the future by removing this |
1471 | 0 | const auto save_n_outputs = this->n_outputs; |
1472 | |
|
1473 | 0 | this->n_outputs = n_outputs; |
1474 | |
|
1475 | 0 | llama_batch_allocr balloc(model.hparams.n_pos_per_embd()); |
1476 | 0 | llama_ubatch ubatch = balloc.ubatch_reserve(n_tokens/n_seqs, n_seqs); |
1477 | |
|
1478 | 0 | auto * res = gf_res_reserve.get(); |
1479 | |
|
1480 | 0 | const auto gparams = graph_params(res, ubatch, mctx, LLM_GRAPH_TYPE_DEFAULT); |
1481 | |
|
1482 | 0 | res->reset(); |
1483 | |
|
1484 | 0 | auto * gf = model.build_graph(gparams); |
1485 | |
|
1486 | 0 | this->n_outputs = save_n_outputs; |
1487 | | |
1488 | | // initialize scheduler with the specified graph |
1489 | 0 | if (split_only) { |
1490 | 0 | if (sizes) { |
1491 | 0 | ggml_backend_sched_reserve_size(sched.get(), gf, sizes); |
1492 | 0 | } else { |
1493 | 0 | ggml_backend_sched_split_graph(sched.get(), gf); |
1494 | 0 | } |
1495 | 0 | } else if (!ggml_backend_sched_reserve(sched.get(), gf)) { |
1496 | 0 | GGML_ASSERT(!sizes); |
1497 | 0 | LLAMA_LOG_ERROR("%s: failed to allocate compute buffers\n", __func__); |
1498 | 0 | return nullptr; |
1499 | 0 | } |
1500 | | |
1501 | 0 | return gf; |
1502 | 0 | } |
1503 | | |
1504 | | llm_graph_params llama_context::graph_params( |
1505 | | llm_graph_result * res, |
1506 | | const llama_ubatch & ubatch, |
1507 | | const llama_memory_context_i * mctx, |
1508 | 0 | llm_graph_type gtype) const { |
1509 | 0 | return { |
1510 | 0 | /*.arch =*/ model.arch, |
1511 | 0 | /*.hparams =*/ model.hparams, |
1512 | 0 | /*.cparams =*/ cparams, |
1513 | 0 | /*.ubatch =*/ ubatch, |
1514 | 0 | /*.gtype =*/ gtype, |
1515 | 0 | /*.sched =*/ sched.get(), |
1516 | 0 | /*.backend_cpu =*/ backend_cpu, |
1517 | 0 | /*.cvec =*/ &cvec, |
1518 | 0 | /*.loras =*/ &loras, |
1519 | 0 | /*.mctx =*/ mctx, |
1520 | 0 | /*.cross =*/ &cross, |
1521 | 0 | /*.n_outputs =*/ n_outputs, |
1522 | 0 | /*.cb =*/ graph_get_cb(), |
1523 | 0 | /*.res =*/ res, |
1524 | 0 | }; |
1525 | 0 | } |
1526 | | |
1527 | | ggml_status llama_context::graph_compute( |
1528 | | ggml_cgraph * gf, |
1529 | 0 | bool batched) { |
1530 | 0 | int n_threads = batched ? cparams.n_threads_batch : cparams.n_threads; |
1531 | 0 | ggml_threadpool_t tp = batched ? threadpool_batch : threadpool; |
1532 | |
|
1533 | 0 | if (backend_cpu != nullptr) { |
1534 | 0 | auto * reg = ggml_backend_dev_backend_reg(ggml_backend_get_device(backend_cpu)); |
1535 | 0 | auto * set_threadpool_fn = (decltype(ggml_backend_cpu_set_threadpool) *) ggml_backend_reg_get_proc_address(reg, "ggml_backend_cpu_set_threadpool"); |
1536 | 0 | if (set_threadpool_fn) { |
1537 | 0 | set_threadpool_fn(backend_cpu, tp); |
1538 | 0 | } |
1539 | 0 | } |
1540 | | |
1541 | | // set the number of threads for all the backends |
1542 | 0 | for (const auto & set_n_threads_fn : set_n_threads_fns) { |
1543 | 0 | set_n_threads_fn.second(set_n_threads_fn.first, n_threads); |
1544 | 0 | } |
1545 | |
|
1546 | 0 | auto status = ggml_backend_sched_graph_compute_async(sched.get(), gf); |
1547 | 0 | if (status != GGML_STATUS_SUCCESS) { |
1548 | 0 | LLAMA_LOG_ERROR("%s: ggml_backend_sched_graph_compute_async failed with error %d\n", __func__, status); |
1549 | 0 | } |
1550 | | |
1551 | | // fprintf(stderr, "splits: %d\n", ggml_backend_sched_get_n_splits(sched)); |
1552 | |
|
1553 | 0 | return status; |
1554 | 0 | } |
1555 | | |
1556 | 0 | llm_graph_cb llama_context::graph_get_cb() const { |
1557 | 0 | return [&](const llama_ubatch & ubatch, ggml_tensor * cur, const char * name, int il) { |
1558 | 0 | if (il >= 0) { |
1559 | 0 | ggml_format_name(cur, "%s-%d", name, il); |
1560 | 0 | } else { |
1561 | 0 | ggml_set_name(cur, name); |
1562 | 0 | } |
1563 | |
|
1564 | 0 | if (!cparams.offload_kqv) { |
1565 | 0 | if (strcmp(name, "kqv_merged_cont") == 0) { |
1566 | | // all nodes between the KV store and the attention output are run on the CPU |
1567 | 0 | ggml_backend_sched_set_tensor_backend(sched.get(), cur, backend_cpu); |
1568 | 0 | } |
1569 | 0 | } |
1570 | | |
1571 | | // norm may be automatically assigned to the backend of the previous layer, increasing data transfer between backends |
1572 | | // FIXME: fix in ggml_backend_sched |
1573 | 0 | const bool full_offload = model.n_gpu_layers() > model.hparams.n_layer; |
1574 | 0 | if (ubatch.n_tokens < 32 || full_offload) { |
1575 | 0 | if (il != -1 && strcmp(name, "norm") == 0) { |
1576 | 0 | const auto & dev_layer = model.dev_layer(il); |
1577 | 0 | for (const auto & backend : backends) { |
1578 | 0 | if (ggml_backend_get_device(backend.get()) == dev_layer) { |
1579 | 0 | if (ggml_backend_supports_op(backend.get(), cur)) { |
1580 | 0 | ggml_backend_sched_set_tensor_backend(sched.get(), cur, backend.get()); |
1581 | 0 | } |
1582 | 0 | } |
1583 | 0 | } |
1584 | 0 | } |
1585 | 0 | } |
1586 | 0 | }; |
1587 | 0 | } |
1588 | | |
1589 | | // |
1590 | | // state save/load |
1591 | | // |
1592 | | |
1593 | | class llama_io_write_dummy : public llama_io_write_i { |
1594 | | public: |
1595 | 0 | llama_io_write_dummy() = default; |
1596 | | |
1597 | 0 | void write(const void * /* src */, size_t size) override { |
1598 | 0 | size_written += size; |
1599 | 0 | } |
1600 | | |
1601 | 0 | void write_tensor(const ggml_tensor * /* tensor */, size_t /* offset */, size_t size) override { |
1602 | 0 | size_written += size; |
1603 | 0 | } |
1604 | | |
1605 | 0 | size_t n_bytes() override { |
1606 | 0 | return size_written; |
1607 | 0 | } |
1608 | | |
1609 | | private: |
1610 | | size_t size_written = 0; |
1611 | | }; |
1612 | | |
1613 | | class llama_io_write_buffer : public llama_io_write_i { |
1614 | | public: |
1615 | | llama_io_write_buffer( |
1616 | 0 | uint8_t * p, size_t len) : ptr(p), buf_size(len) {} |
1617 | | |
1618 | 0 | void write(const void * src, size_t size) override { |
1619 | 0 | if (size > buf_size) { |
1620 | 0 | throw std::runtime_error("unexpectedly reached end of buffer"); |
1621 | 0 | } |
1622 | 0 | memcpy(ptr, src, size); |
1623 | 0 | ptr += size; |
1624 | 0 | size_written += size; |
1625 | 0 | buf_size -= size; |
1626 | 0 | } |
1627 | | |
1628 | 0 | void write_tensor(const ggml_tensor * tensor, size_t offset, size_t size) override { |
1629 | 0 | if (size > buf_size) { |
1630 | 0 | throw std::runtime_error("unexpectedly reached end of buffer"); |
1631 | 0 | } |
1632 | 0 | ggml_backend_tensor_get(tensor, ptr, offset, size); |
1633 | 0 | ptr += size; |
1634 | 0 | size_written += size; |
1635 | 0 | buf_size -= size; |
1636 | 0 | } |
1637 | | |
1638 | 0 | size_t n_bytes() override { |
1639 | 0 | return size_written; |
1640 | 0 | } |
1641 | | |
1642 | | private: |
1643 | | uint8_t * ptr; |
1644 | | size_t buf_size = 0; |
1645 | | size_t size_written = 0; |
1646 | | }; |
1647 | | |
1648 | | class llama_io_read_buffer : public llama_io_read_i { |
1649 | | public: |
1650 | 0 | llama_io_read_buffer(const uint8_t * p, size_t len) : ptr(p), buf_size(len) {} |
1651 | | |
1652 | 0 | const uint8_t * read(size_t size) override { |
1653 | 0 | const uint8_t * base_ptr = ptr; |
1654 | 0 | if (size > buf_size) { |
1655 | 0 | throw std::runtime_error("unexpectedly reached end of buffer"); |
1656 | 0 | } |
1657 | 0 | ptr += size; |
1658 | 0 | size_read += size; |
1659 | 0 | buf_size -= size; |
1660 | 0 | return base_ptr; |
1661 | 0 | } |
1662 | | |
1663 | 0 | void read_to(void * dst, size_t size) override { |
1664 | 0 | memcpy(dst, read(size), size); |
1665 | 0 | } |
1666 | | |
1667 | 0 | size_t n_bytes() override { |
1668 | 0 | return size_read; |
1669 | 0 | } |
1670 | | |
1671 | | private: |
1672 | | const uint8_t * ptr; |
1673 | | size_t buf_size = 0; |
1674 | | size_t size_read = 0; |
1675 | | }; |
1676 | | |
1677 | | class llama_io_write_file : public llama_io_write_i { |
1678 | | public: |
1679 | 0 | llama_io_write_file(llama_file * f) : file(f) {} |
1680 | | |
1681 | 0 | void write(const void * src, size_t size) override { |
1682 | 0 | file->write_raw(src, size); |
1683 | 0 | size_written += size; |
1684 | 0 | } |
1685 | | |
1686 | 0 | void write_tensor(const ggml_tensor * tensor, size_t offset, size_t size) override { |
1687 | 0 | temp_buffer.resize(size); |
1688 | 0 | ggml_backend_tensor_get(tensor, temp_buffer.data(), offset, size); |
1689 | 0 | write(temp_buffer.data(), temp_buffer.size()); |
1690 | 0 | } |
1691 | | |
1692 | 0 | size_t n_bytes() override { |
1693 | 0 | return size_written; |
1694 | 0 | } |
1695 | | |
1696 | | private: |
1697 | | llama_file * file; |
1698 | | size_t size_written = 0; |
1699 | | std::vector<uint8_t> temp_buffer; |
1700 | | }; |
1701 | | |
1702 | | class llama_io_read_file : public llama_io_read_i { |
1703 | | public: |
1704 | 0 | llama_io_read_file(llama_file * f) : file(f) {} |
1705 | | |
1706 | 0 | void read_to(void * dst, size_t size) override { |
1707 | 0 | file->read_raw(dst, size); |
1708 | 0 | size_read += size; |
1709 | 0 | } |
1710 | | |
1711 | 0 | const uint8_t * read(size_t size) override { |
1712 | 0 | temp_buffer.resize(size); |
1713 | 0 | read_to(temp_buffer.data(), size); |
1714 | 0 | return temp_buffer.data(); |
1715 | 0 | } |
1716 | | |
1717 | 0 | size_t n_bytes() override { |
1718 | 0 | return size_read; |
1719 | 0 | } |
1720 | | |
1721 | | private: |
1722 | | llama_file * file; |
1723 | | size_t size_read = 0; |
1724 | | std::vector<uint8_t> temp_buffer; |
1725 | | }; |
1726 | | |
1727 | 0 | size_t llama_context::state_get_size() { |
1728 | 0 | llama_io_write_dummy io; |
1729 | 0 | try { |
1730 | 0 | return state_write_data(io); |
1731 | 0 | } catch (const std::exception & err) { |
1732 | 0 | LLAMA_LOG_ERROR("%s: error getting state size: %s\n", __func__, err.what()); |
1733 | 0 | return 0; |
1734 | 0 | } |
1735 | 0 | } |
1736 | | |
1737 | 0 | size_t llama_context::state_get_data(uint8_t * dst, size_t size) { |
1738 | 0 | llama_io_write_buffer io(dst, size); |
1739 | 0 | try { |
1740 | 0 | return state_write_data(io); |
1741 | 0 | } catch (const std::exception & err) { |
1742 | 0 | LLAMA_LOG_ERROR("%s: error saving state: %s\n", __func__, err.what()); |
1743 | 0 | return 0; |
1744 | 0 | } |
1745 | 0 | } |
1746 | | |
1747 | 0 | size_t llama_context::state_set_data(const uint8_t * src, size_t size) { |
1748 | 0 | llama_io_read_buffer io(src, size); |
1749 | 0 | try { |
1750 | 0 | return state_read_data(io); |
1751 | 0 | } catch (const std::exception & err) { |
1752 | 0 | LLAMA_LOG_ERROR("%s: error loading state: %s\n", __func__, err.what()); |
1753 | 0 | return 0; |
1754 | 0 | } |
1755 | 0 | } |
1756 | | |
1757 | 0 | size_t llama_context::state_seq_get_size(llama_seq_id seq_id, llama_state_seq_flags flags) { |
1758 | 0 | llama_io_write_dummy io; |
1759 | 0 | try { |
1760 | 0 | return state_seq_write_data(io, seq_id, flags); |
1761 | 0 | } catch (const std::exception & err) { |
1762 | 0 | LLAMA_LOG_ERROR("%s: error getting state size: %s\n", __func__, err.what()); |
1763 | 0 | return 0; |
1764 | 0 | } |
1765 | 0 | } |
1766 | | |
1767 | 0 | size_t llama_context::state_seq_get_data(llama_seq_id seq_id, uint8_t * dst, size_t size, llama_state_seq_flags flags) { |
1768 | 0 | llama_io_write_buffer io(dst, size); |
1769 | 0 | try { |
1770 | 0 | return state_seq_write_data(io, seq_id, flags); |
1771 | 0 | } catch (const std::exception & err) { |
1772 | 0 | LLAMA_LOG_ERROR("%s: error saving state: %s\n", __func__, err.what()); |
1773 | 0 | return 0; |
1774 | 0 | } |
1775 | 0 | } |
1776 | | |
1777 | 0 | size_t llama_context::state_seq_set_data(llama_seq_id seq_id, const uint8_t * src, size_t size, llama_state_seq_flags flags) { |
1778 | 0 | llama_io_read_buffer io(src, size); |
1779 | 0 | try { |
1780 | 0 | return state_seq_read_data(io, seq_id, flags); |
1781 | 0 | } catch (const std::exception & err) { |
1782 | 0 | LLAMA_LOG_ERROR("%s: error loading state: %s\n", __func__, err.what()); |
1783 | 0 | return 0; |
1784 | 0 | } |
1785 | 0 | } |
1786 | | |
1787 | 0 | bool llama_context::state_load_file(const char * filepath, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) { |
1788 | 0 | llama_file file(filepath, "rb"); |
1789 | | |
1790 | | // sanity checks |
1791 | 0 | { |
1792 | 0 | const uint32_t magic = file.read_u32(); |
1793 | 0 | const uint32_t version = file.read_u32(); |
1794 | |
|
1795 | 0 | if (magic != LLAMA_SESSION_MAGIC || version != LLAMA_SESSION_VERSION) { |
1796 | 0 | LLAMA_LOG_ERROR("%s: unknown (magic, version) for session file: %08x, %08x\n", __func__, magic, version); |
1797 | 0 | return false; |
1798 | 0 | } |
1799 | 0 | } |
1800 | | |
1801 | | // load the prompt |
1802 | 0 | { |
1803 | 0 | const uint32_t n_token_count = file.read_u32(); |
1804 | |
|
1805 | 0 | if (n_token_count > n_token_capacity) { |
1806 | 0 | LLAMA_LOG_ERROR("%s: token count in session file exceeded capacity! %u > %zu\n", __func__, n_token_count, n_token_capacity); |
1807 | 0 | return false; |
1808 | 0 | } |
1809 | | |
1810 | 0 | file.read_raw(tokens_out, sizeof(llama_token) * n_token_count); |
1811 | 0 | *n_token_count_out = n_token_count; |
1812 | 0 | } |
1813 | | |
1814 | | // restore the context state |
1815 | 0 | { |
1816 | 0 | const size_t n_state_size_cur = file.size() - file.tell(); |
1817 | |
|
1818 | 0 | llama_io_read_file io(&file); |
1819 | 0 | const size_t n_read = state_read_data(io); |
1820 | |
|
1821 | 0 | if (n_read != n_state_size_cur) { |
1822 | 0 | LLAMA_LOG_ERROR("%s: did not read all of the session file data! size %zu, got %zu\n", __func__, n_state_size_cur, n_read); |
1823 | 0 | return false; |
1824 | 0 | } |
1825 | 0 | } |
1826 | | |
1827 | 0 | return true; |
1828 | 0 | } |
1829 | | |
1830 | 0 | bool llama_context::state_save_file(const char * filepath, const llama_token * tokens, size_t n_token_count) { |
1831 | 0 | llama_file file(filepath, "wb"); |
1832 | |
|
1833 | 0 | file.write_u32(LLAMA_SESSION_MAGIC); |
1834 | 0 | file.write_u32(LLAMA_SESSION_VERSION); |
1835 | | |
1836 | | // save the prompt |
1837 | 0 | file.write_u32((uint32_t) n_token_count); |
1838 | 0 | file.write_raw(tokens, sizeof(llama_token) * n_token_count); |
1839 | | |
1840 | | // save the context state using stream saving |
1841 | 0 | llama_io_write_file io(&file); |
1842 | 0 | state_write_data(io); |
1843 | |
|
1844 | 0 | return true; |
1845 | 0 | } |
1846 | | |
1847 | 0 | size_t llama_context::state_seq_load_file(llama_seq_id seq_id, const char * filepath, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) { |
1848 | 0 | llama_file file(filepath, "rb"); |
1849 | | |
1850 | | // version checks |
1851 | 0 | { |
1852 | 0 | const uint32_t magic = file.read_u32(); |
1853 | 0 | const uint32_t version = file.read_u32(); |
1854 | |
|
1855 | 0 | if (magic != LLAMA_STATE_SEQ_MAGIC || version != LLAMA_STATE_SEQ_VERSION) { |
1856 | 0 | LLAMA_LOG_ERROR("%s: unknown (magic, version) for sequence state file: %08x, %08x\n", __func__, magic, version); |
1857 | 0 | return 0; |
1858 | 0 | } |
1859 | 0 | } |
1860 | | |
1861 | | // load the prompt |
1862 | 0 | { |
1863 | 0 | const uint32_t n_token_count = file.read_u32(); |
1864 | |
|
1865 | 0 | if (n_token_count > n_token_capacity) { |
1866 | 0 | LLAMA_LOG_ERROR("%s: token count in sequence state file exceeded capacity! %u > %zu\n", __func__, n_token_count, n_token_capacity); |
1867 | 0 | return 0; |
1868 | 0 | } |
1869 | | |
1870 | 0 | file.read_raw(tokens_out, sizeof(llama_token) * n_token_count); |
1871 | 0 | *n_token_count_out = n_token_count; |
1872 | 0 | } |
1873 | | |
1874 | | // restore the context state |
1875 | 0 | { |
1876 | 0 | const size_t state_size = file.size() - file.tell(); |
1877 | 0 | llama_io_read_file io(&file); |
1878 | 0 | const size_t nread = state_seq_read_data(io, seq_id, 0); |
1879 | 0 | if (!nread) { |
1880 | 0 | LLAMA_LOG_ERROR("%s: failed to restore sequence state\n", __func__); |
1881 | 0 | return 0; |
1882 | 0 | } |
1883 | 0 | GGML_ASSERT(nread <= state_size); |
1884 | 0 | GGML_ASSERT(nread + sizeof(uint32_t) * 3 + sizeof(llama_token) * *n_token_count_out == file.tell()); |
1885 | 0 | } |
1886 | | |
1887 | 0 | return file.tell(); |
1888 | 0 | } |
1889 | | |
1890 | 0 | size_t llama_context::state_seq_save_file(llama_seq_id seq_id, const char * filepath, const llama_token * tokens, size_t n_token_count) { |
1891 | 0 | llama_file file(filepath, "wb"); |
1892 | |
|
1893 | 0 | file.write_u32(LLAMA_STATE_SEQ_MAGIC); |
1894 | 0 | file.write_u32(LLAMA_STATE_SEQ_VERSION); |
1895 | | |
1896 | | // save the prompt |
1897 | 0 | file.write_u32((uint32_t) n_token_count); |
1898 | 0 | file.write_raw(tokens, sizeof(llama_token) * n_token_count); |
1899 | | |
1900 | | // save the context state using stream saving |
1901 | 0 | llama_io_write_file io(&file); |
1902 | 0 | state_seq_write_data(io, seq_id, 0); |
1903 | |
|
1904 | 0 | const size_t res = file.tell(); |
1905 | 0 | GGML_ASSERT(res == sizeof(uint32_t) * 3 + sizeof(llama_token) * n_token_count + io.n_bytes()); |
1906 | |
|
1907 | 0 | return res; |
1908 | 0 | } |
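A hedged usage sketch of the save/load path implemented above, driven through the public wrappers llama_state_save_file and llama_state_load_file that appear later in this file; the helper name save_and_reload_session and the path "session.bin" are made up, and ctx is assumed to hold the state produced by decoding tokens:

#include <vector>
#include "llama.h"

// assumes: `ctx` is valid and its state corresponds to `tokens` (e.g. after llama_decode on them)
static bool save_and_reload_session(llama_context * ctx, const std::vector<llama_token> & tokens) {
    // writes magic, version, token count, the tokens, then the streamed context state
    if (!llama_state_save_file(ctx, "session.bin", tokens.data(), tokens.size())) {
        return false;
    }

    std::vector<llama_token> restored(tokens.size());
    size_t n_restored = 0;
    // fails on a magic/version mismatch, insufficient token capacity, or a truncated state blob
    if (!llama_state_load_file(ctx, "session.bin", restored.data(), restored.size(), &n_restored)) {
        return false;
    }
    restored.resize(n_restored);

    return restored == tokens;
}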
1909 | | |
1910 | 0 | size_t llama_context::state_write_data(llama_io_write_i & io) { |
1911 | 0 | LLAMA_LOG_DEBUG("%s: writing state\n", __func__); |
1912 | | |
1913 | | // write model info |
1914 | 0 | { |
1915 | 0 | LLAMA_LOG_DEBUG("%s: - writing model info\n", __func__); |
1916 | |
|
1917 | 0 | const std::string arch_str = llm_arch_name(model.arch); |
1918 | 0 | io.write_string(arch_str); |
1919 | | // TODO: add more model-specific info which should prevent loading the session file if not identical |
1920 | 0 | } |
1921 | | |
1922 | | // write output ids |
1923 | 0 | { |
1924 | 0 | LLAMA_LOG_DEBUG("%s: - writing output ids\n", __func__); |
1925 | |
|
1926 | 0 | const auto n_outputs = this->n_outputs; |
1927 | 0 | const auto & output_ids = this->output_ids; |
1928 | |
|
1929 | 0 | std::vector<int32_t> w_output_pos; |
1930 | |
|
1931 | 0 | w_output_pos.resize(n_outputs); |
1932 | | |
1933 | | // build a more compact representation of the output ids |
1934 | 0 | for (size_t i = 0; i < n_batch(); ++i) { |
1935 | | // map an output id to a position in the batch |
1936 | 0 | int64_t pos = output_ids[i]; |
1937 | 0 | if (pos >= 0) { |
1938 | 0 | GGML_ASSERT(pos < n_outputs); |
1939 | 0 | w_output_pos[pos] = i; |
1940 | 0 | } |
1941 | 0 | } |
1942 | |
|
1943 | 0 | io.write(&n_outputs, sizeof(n_outputs)); |
1944 | |
|
1945 | 0 | if (n_outputs) { |
1946 | 0 | io.write(w_output_pos.data(), n_outputs * sizeof(int32_t)); |
1947 | 0 | } |
1948 | 0 | } |
1949 | | |
1950 | | // write logits |
1951 | 0 | { |
1952 | 0 | LLAMA_LOG_DEBUG("%s: - writing logits\n", __func__); |
1953 | |
|
1954 | 0 | const uint64_t logits_size = std::min((uint64_t) this->logits_size, (uint64_t) n_outputs * model.vocab.n_tokens()); |
1955 | |
|
1956 | 0 | io.write(&logits_size, sizeof(logits_size)); |
1957 | |
|
1958 | 0 | if (logits_size) { |
1959 | 0 | io.write(logits, logits_size * sizeof(float)); |
1960 | 0 | } |
1961 | 0 | } |
1962 | | |
1963 | | // write embeddings |
1964 | 0 | { |
1965 | 0 | LLAMA_LOG_DEBUG("%s: - writing embeddings\n", __func__); |
1966 | |
|
1967 | 0 | const uint64_t embd_size = std::min((uint64_t) this->embd_size, (uint64_t) n_outputs * model.hparams.n_embd); |
1968 | |
|
1969 | 0 | io.write(&embd_size, sizeof(embd_size)); |
1970 | |
|
1971 | 0 | if (embd_size) { |
1972 | 0 | io.write(embd, embd_size * sizeof(float)); |
1973 | 0 | } |
1974 | 0 | } |
1975 | |
|
1976 | 0 | if (memory != nullptr) { |
1977 | 0 | LLAMA_LOG_DEBUG("%s: - writing memory module\n", __func__); |
1978 | 0 | memory->state_write(io); |
1979 | 0 | } |
1980 | |
|
1981 | 0 | return io.n_bytes(); |
1982 | 0 | } |
1983 | | |
1984 | 0 | size_t llama_context::state_read_data(llama_io_read_i & io) { |
1985 | 0 | LLAMA_LOG_DEBUG("%s: reading state\n", __func__); |
1986 | | |
1987 | | // read model info |
1988 | 0 | { |
1989 | 0 | LLAMA_LOG_DEBUG("%s: - reading model info\n", __func__); |
1990 | |
|
1991 | 0 | const std::string cur_arch_str = llm_arch_name(model.arch); |
1992 | |
|
1993 | 0 | std::string arch_str; |
1994 | 0 | io.read_string(arch_str); |
1995 | 0 | if (cur_arch_str != arch_str) { |
1996 | 0 | throw std::runtime_error(format("wrong model arch: '%s' instead of '%s'", arch_str.c_str(), cur_arch_str.c_str())); |
1997 | 0 | } |
1998 | | // TODO: add more info which needs to be identical but which is not verified otherwise |
1999 | 0 | } |
2000 | | |
2001 | | // read output ids |
2002 | 0 | { |
2003 | 0 | LLAMA_LOG_DEBUG("%s: - reading output ids\n", __func__); |
2004 | |
|
2005 | 0 | auto n_outputs = this->n_outputs; |
2006 | 0 | io.read_to(&n_outputs, sizeof(n_outputs)); |
2007 | |
|
2008 | 0 | if (n_outputs > output_reserve(n_outputs)) { |
2009 | 0 | throw std::runtime_error("could not reserve outputs"); |
2010 | 0 | } |
2011 | | |
2012 | 0 | std::vector<int32_t> output_pos; |
2013 | |
|
2014 | 0 | if (n_outputs) { |
2015 | 0 | output_pos.resize(n_outputs); |
2016 | 0 | io.read_to(output_pos.data(), n_outputs * sizeof(int32_t)); |
2017 | |
|
2018 | 0 | for (int32_t i = 0; i < (int32_t) output_pos.size(); ++i) { |
2019 | 0 | int32_t id = output_pos[i]; |
2020 | 0 | if ((uint32_t) id >= n_batch()) { |
2021 | 0 | throw std::runtime_error(format("invalid output id, %d does not fit in batch size of %u", id, n_batch())); |
2022 | 0 | } |
2023 | 0 | this->output_ids[id] = i; |
2024 | 0 | } |
2025 | | |
2026 | 0 | this->n_outputs = n_outputs; |
2027 | 0 | } |
2028 | 0 | } |
2029 | | |
2030 | | // read logits |
2031 | 0 | { |
2032 | 0 | LLAMA_LOG_DEBUG("%s: - reading logits\n", __func__); |
2033 | |
|
2034 | 0 | uint64_t logits_size; |
2035 | 0 | io.read_to(&logits_size, sizeof(logits_size)); |
2036 | |
|
2037 | 0 | if (this->logits_size < logits_size) { |
2038 | 0 | throw std::runtime_error("logits buffer too small"); |
2039 | 0 | } |
2040 | | |
2041 | 0 | if (logits_size) { |
2042 | 0 | io.read_to(this->logits, logits_size * sizeof(float)); |
2043 | 0 | } |
2044 | 0 | } |
2045 | | |
2046 | | // read embeddings |
2047 | 0 | { |
2048 | 0 | LLAMA_LOG_DEBUG("%s: - reading embeddings\n", __func__); |
2049 | |
|
2050 | 0 | uint64_t embd_size; |
2051 | 0 | io.read_to(&embd_size, sizeof(embd_size)); |
2052 | |
|
2053 | 0 | if (this->embd_size < embd_size) { |
2054 | 0 | throw std::runtime_error("embeddings buffer too small"); |
2055 | 0 | } |
2056 | | |
2057 | 0 | if (embd_size) { |
2058 | 0 | io.read_to(this->embd, embd_size * sizeof(float)); |
2059 | 0 | } |
2060 | 0 | } |
2061 | | |
2062 | 0 | if (memory) { |
2063 | 0 | LLAMA_LOG_DEBUG("%s: - reading memory module\n", __func__); |
2064 | |
|
2065 | 0 | memory->state_read(io); |
2066 | 0 | } |
2067 | |
|
2068 | 0 | return io.n_bytes(); |
2069 | 0 | } |
2070 | | |
2071 | 0 | size_t llama_context::state_seq_write_data(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) { |
2072 | 0 | GGML_UNUSED(seq_id); |
2073 | |
|
2074 | 0 | if (memory) { |
2075 | 0 | memory->state_write(io, seq_id, flags); |
2076 | 0 | } |
2077 | |
|
2078 | 0 | return io.n_bytes(); |
2079 | 0 | } |
2080 | | |
2081 | 0 | size_t llama_context::state_seq_read_data(llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) { |
2082 | 0 | GGML_UNUSED(seq_id); |
2083 | |
|
2084 | 0 | if (memory) { |
2085 | 0 | memory->state_read(io, seq_id, flags); |
2086 | 0 | } |
2087 | |
|
2088 | 0 | return io.n_bytes(); |
2089 | 0 | } |
2090 | | |
2091 | | // |
2092 | | // perf |
2093 | | // |
2094 | | |
2095 | 0 | llama_perf_context_data llama_context::perf_get_data() const { |
2096 | 0 | llama_perf_context_data data = {}; |
2097 | |
|
2098 | 0 | data.t_start_ms = 1e-3 * t_start_us; |
2099 | 0 | data.t_load_ms = 1e-3 * t_load_us; |
2100 | 0 | data.t_p_eval_ms = 1e-3 * t_p_eval_us; |
2101 | 0 | data.t_eval_ms = 1e-3 * t_eval_us; |
2102 | 0 | data.n_p_eval = std::max(1, n_p_eval); |
2103 | 0 | data.n_eval = std::max(1, n_eval); |
2104 | 0 | data.n_reused = std::max(0, n_reused); |
2105 | |
|
2106 | 0 | return data; |
2107 | 0 | } |
2108 | | |
2109 | 0 | void llama_context::perf_reset() { |
2110 | 0 | t_start_us = ggml_time_us(); |
2111 | 0 | t_eval_us = n_eval = 0; |
2112 | 0 | t_p_eval_us = n_p_eval = 0; |
2113 | 0 | n_reused = 0; |
2114 | 0 | } |
2115 | | |
2116 | 0 | std::map<ggml_backend_buffer_type_t, llama_memory_breakdown_data> llama_context::memory_breakdown() const { |
2117 | 0 | std::map<ggml_backend_buffer_type_t, llama_memory_breakdown_data> ret; |
2118 | 0 | for (const auto & [buft, size] : model.memory_breakdown()) { |
2119 | 0 | ret[buft].model += size; |
2120 | 0 | } |
2121 | 0 | if (memory) { |
2122 | 0 | for (const auto & [buft, size] : memory->memory_breakdown()) { |
2123 | 0 | ret[buft].context += size; |
2124 | 0 | } |
2125 | 0 | } |
2126 | 0 | if (model.hparams.no_alloc) { |
2127 | 0 | for (size_t i = 0; i < backends.size(); ++i) { |
2128 | 0 | ggml_backend_t backend = backends[i].get(); |
2129 | 0 | ggml_backend_buffer_type_t buft = ggml_backend_sched_get_buffer_type(sched.get(), backend); |
2130 | 0 | ret[buft].compute += backend_buf_exp_size[i]; |
2131 | 0 | } |
2132 | 0 | } else { |
2133 | 0 | for (const auto & backend_ptr : backends) { |
2134 | 0 | ggml_backend_t backend = backend_ptr.get(); |
2135 | 0 | ggml_backend_buffer_type_t buft = ggml_backend_sched_get_buffer_type(sched.get(), backend); |
2136 | 0 | ret[buft].compute += ggml_backend_sched_get_buffer_size(sched.get(), backend); |
2137 | 0 | } |
2138 | 0 | } |
2139 | 0 | return ret; |
2140 | 0 | } |
2141 | | |
2142 | | // |
2143 | | // training |
2144 | | // |
2145 | | |
2146 | 0 | static void llama_set_param(struct ggml_tensor * tensor, llama_opt_param_filter param_filter, void * userdata) { |
2147 | 0 | if (!tensor || tensor->type != GGML_TYPE_F32) { |
2148 | 0 | return; |
2149 | 0 | } |
2150 | 0 | if (!param_filter(tensor, userdata)) { |
2151 | 0 | return; |
2152 | 0 | } |
2153 | 0 | if (strcmp(tensor->name, "token_embd.weight") == 0) { |
2154 | 0 | return; // FIXME |
2155 | 0 | } |
2156 | 0 | if (strcmp(tensor->name, "rope_freqs.weight") == 0) { |
2157 | 0 | return; // FIXME |
2158 | 0 | } |
2159 | 0 | ggml_set_param(tensor); |
2160 | 0 | } |
2161 | | |
2162 | 0 | void llama_context::opt_init(struct llama_model * model, struct llama_opt_params lopt_params) { |
2163 | 0 | GGML_ASSERT(!opt_ctx); |
2164 | 0 | model->hparams.n_ctx_train = lopt_params.n_ctx_train > 0 ? lopt_params.n_ctx_train : n_ctx(); |
2165 | 0 | const uint32_t n_batch = std::min(this->n_batch(), model->hparams.n_ctx_train); |
2166 | 0 | const uint32_t n_ubatch = std::min(this->n_ubatch(), n_batch); |
2167 | 0 | GGML_ASSERT(model->hparams.n_ctx_train % n_batch == 0); |
2168 | 0 | GGML_ASSERT(n_batch % n_ubatch == 0); |
2169 | |
|
2170 | 0 | ggml_opt_params opt_params = ggml_opt_default_params(sched.get(), GGML_OPT_LOSS_TYPE_CROSS_ENTROPY); |
2171 | 0 | opt_params.opt_period = n_batch / n_ubatch; |
2172 | 0 | opt_params.get_opt_pars = lopt_params.get_opt_pars; |
2173 | 0 | opt_params.get_opt_pars_ud = lopt_params.get_opt_pars_ud; |
2174 | 0 | opt_params.optimizer = lopt_params.optimizer_type; |
2175 | 0 | opt_ctx = ggml_opt_init(opt_params); |
2176 | |
|
2177 | 0 | llama_opt_param_filter param_filter = lopt_params.param_filter; |
2178 | 0 | void * param_filter_ud = lopt_params.param_filter_ud; |
2179 | | |
2180 | | //llama_set_param(model->tok_embd, param_filter, param_filter_ud); // FIXME |
2181 | 0 | llama_set_param(model->type_embd, param_filter, param_filter_ud); |
2182 | 0 | llama_set_param(model->pos_embd, param_filter, param_filter_ud); |
2183 | 0 | llama_set_param(model->tok_norm, param_filter, param_filter_ud); |
2184 | 0 | llama_set_param(model->tok_norm_b, param_filter, param_filter_ud); |
2185 | 0 | llama_set_param(model->output_norm, param_filter, param_filter_ud); |
2186 | 0 | llama_set_param(model->output_norm_b, param_filter, param_filter_ud); |
2187 | 0 | llama_set_param(model->output, param_filter, param_filter_ud); |
2188 | 0 | llama_set_param(model->output_b, param_filter, param_filter_ud); |
2189 | 0 | llama_set_param(model->output_norm_enc, param_filter, param_filter_ud); |
2190 | 0 | llama_set_param(model->cls, param_filter, param_filter_ud); |
2191 | 0 | llama_set_param(model->cls_b, param_filter, param_filter_ud); |
2192 | 0 | llama_set_param(model->cls_out, param_filter, param_filter_ud); |
2193 | 0 | llama_set_param(model->cls_out_b, param_filter, param_filter_ud); |
2194 | |
|
2195 | 0 | for (struct llama_layer & layer : model->layers) { |
2196 | 0 | for (size_t i = 0; i < sizeof(layer)/sizeof(struct ggml_tensor *); ++i) { |
2197 | 0 | llama_set_param(reinterpret_cast<struct ggml_tensor **>(&layer)[i], param_filter, param_filter_ud); |
2198 | 0 | } |
2199 | 0 | } |
2200 | 0 | } |
2201 | | |
2202 | | void llama_context::opt_epoch_iter( |
2203 | | ggml_opt_dataset_t dataset, |
2204 | | ggml_opt_result_t result, |
2205 | | const std::vector<llama_token> & tokens, |
2206 | | const std::vector<llama_token> & labels_sparse, |
2207 | | llama_batch & batch, |
2208 | | ggml_opt_epoch_callback callback, |
2209 | | bool train, |
2210 | | int64_t idata_in_loop, |
2211 | | int64_t ndata_in_loop, |
2212 | 0 | int64_t t_loop_start) { |
2213 | 0 | GGML_ASSERT(opt_ctx); |
2214 | 0 | const uint32_t n_ctx = llama_model_n_ctx_train(&model); |
2215 | 0 | const uint32_t n_batch = std::min(this->n_batch(), n_ctx); |
2216 | 0 | const uint32_t n_ubatch = std::min(this->n_ubatch(), n_batch); |
2217 | |
|
2218 | 0 | memory->clear(true); |
2219 | |
|
2220 | 0 | for (uint32_t pos_ctx = 0; pos_ctx < n_ctx; pos_ctx += n_batch) { |
2221 | 0 | batch.n_tokens = n_batch; |
2222 | 0 | for (uint32_t pos_batch = 0; pos_batch < n_batch; ++pos_batch) { |
2223 | 0 | batch.token [pos_batch] = tokens[pos_ctx + pos_batch]; |
2224 | 0 | batch.pos [pos_batch] = pos_ctx + pos_batch; |
2225 | 0 | batch.n_seq_id[pos_batch] = 1; |
2226 | 0 | batch.seq_id [pos_batch][0] = 0; |
2227 | 0 | batch.logits [pos_batch] = true; |
2228 | 0 | } |
2229 | |
|
2230 | 0 | if (!balloc->init(batch, model.vocab, nullptr, model.hparams.n_embd_inp(), cparams.kv_unified ? LLAMA_MAX_SEQ : cparams.n_seq_max, true)) { |
2231 | 0 | LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__); |
2232 | 0 | return; |
2233 | 0 | } |
2234 | | |
2235 | 0 | const uint32_t n_tokens_all = balloc->get_n_tokens(); |
2236 | |
|
2237 | 0 | n_queued_tokens += n_tokens_all; |
2238 | |
|
2239 | 0 | embd_seq.clear(); |
2240 | |
|
2241 | 0 | uint32_t n_outputs_all = n_tokens_all; |
2242 | |
|
2243 | 0 | auto mctx = memory->init_batch(*balloc, cparams.n_ubatch, true); |
2244 | 0 | if (!mctx || mctx->get_status() != LLAMA_MEMORY_STATUS_SUCCESS) { |
2245 | 0 | LLAMA_LOG_ERROR("%s: failed to initialize the memory context for the batch\n", __func__); |
2246 | 0 | break; |
2247 | 0 | } |
2248 | | |
2249 | | // reserve output buffer |
2250 | 0 | if (output_reserve(n_outputs_all) < n_outputs_all) { |
2251 | 0 | LLAMA_LOG_ERROR("%s: could not reserve space for batch with %d outputs\n", __func__, n_outputs_all); |
2252 | 0 | GGML_ABORT("TODO: handle this error"); |
2253 | 0 | } |
2254 | |
|
2255 | 0 | uint32_t pos_batch = 0; |
2256 | 0 | do { |
2257 | 0 | const auto & ubatch = mctx->get_ubatch(); |
2258 | |
|
2259 | 0 | n_outputs = ubatch.n_tokens; |
2260 | |
|
2261 | 0 | if (!mctx->apply()) { |
2262 | 0 | LLAMA_LOG_ERROR("%s: failed to update the memory context\n", __func__); |
2263 | 0 | break; |
2264 | 0 | } |
2265 | | |
2266 | 0 | auto * res = gf_res_prev.get(); |
2267 | |
|
2268 | 0 | const auto gparams = graph_params(res, ubatch, mctx.get(), LLM_GRAPH_TYPE_DEFAULT); |
2269 | |
|
2270 | 0 | res->reset(); |
2271 | |
|
2272 | 0 | auto * gf = model.build_graph(gparams); |
2273 | |
|
2274 | 0 | struct ggml_context * ctx_compute_opt; |
2275 | 0 | { |
2276 | 0 | const size_t size_gf = ggml_graph_size(gf); |
2277 | 0 | const size_t size_meta = 4*size_gf*ggml_tensor_overhead() + 2*ggml_graph_overhead_custom(size_gf, /*grads = */ true); |
2278 | 0 | struct ggml_init_params params = { |
2279 | 0 | /*.mem_size =*/ size_meta, |
2280 | 0 | /*.mem_buffer =*/ nullptr, |
2281 | 0 | /*.no_alloc =*/ true, |
2282 | 0 | }; |
2283 | 0 | ctx_compute_opt = ggml_init(params); |
2284 | 0 | } |
2285 | 0 | ggml_opt_prepare_alloc(opt_ctx, ctx_compute_opt, gf, res->get_tokens(), res->get_logits()); |
2286 | 0 | ggml_opt_alloc(opt_ctx, train); |
2287 | |
|
2288 | 0 | res->set_inputs(&ubatch); |
2289 | 0 | { |
2290 | 0 | struct ggml_tensor * labels = ggml_opt_labels(opt_ctx); |
2291 | 0 | GGML_ASSERT(labels->ne[1] == n_ubatch); |
2292 | 0 | ggml_set_zero(labels); |
2293 | 0 | const float onef = 1.0f; |
2294 | 0 | for (uint32_t pos_ubatch = 0; pos_ubatch < n_ubatch; ++pos_ubatch) { |
2295 | 0 | const uint32_t ilabel = pos_ctx + pos_batch + pos_ubatch; |
2296 | 0 | GGML_ASSERT(labels_sparse[ilabel] < labels->ne[0]); |
2297 | 0 | ggml_backend_tensor_set(labels, &onef, (pos_ubatch*labels->ne[0] + labels_sparse[ilabel])*sizeof(float), sizeof(float)); |
2298 | 0 | } |
2299 | 0 | } |
2300 | 0 | ggml_opt_eval(opt_ctx, result); |
2301 | 0 | if (callback) { |
2302 | 0 | callback(train, opt_ctx, dataset, result, idata_in_loop + (pos_ctx + pos_batch)/n_ubatch + 1, ndata_in_loop, t_loop_start); |
2303 | 0 | } |
2304 | 0 | ggml_free(ctx_compute_opt); |
2305 | |
|
2306 | 0 | pos_batch += ubatch.n_tokens; |
2307 | 0 | } while (mctx->next()); |
2308 | 0 | } |
2309 | 0 | } |
2310 | | |
2311 | | void llama_context::opt_epoch( |
2312 | | ggml_opt_dataset_t dataset, |
2313 | | ggml_opt_result_t result_train, |
2314 | | ggml_opt_result_t result_eval, |
2315 | | int64_t idata_split, |
2316 | | ggml_opt_epoch_callback callback_train, |
2317 | 0 | ggml_opt_epoch_callback callback_eval) { |
2318 | 0 | const uint32_t n_ctx = this->n_ctx(); |
2319 | 0 | const uint32_t n_batch = std::min(cparams.n_batch, n_ctx); |
2320 | 0 | const uint32_t n_ubatch = std::min(cparams.n_ubatch, n_batch); |
2321 | 0 | const int64_t ndata = ggml_opt_dataset_ndata(dataset); |
2322 | |
|
2323 | 0 | GGML_ASSERT(idata_split >= 0); |
2324 | 0 | GGML_ASSERT(idata_split <= ndata); |
2325 | |
|
2326 | 0 | const uint32_t ubatch_per_ctx = n_ctx / n_ubatch; |
2327 | |
|
2328 | 0 | struct llama_batch batch = llama_batch_init(n_batch, 0, 1); |
2329 | 0 | std::vector<llama_token> tokens(n_ctx); |
2330 | 0 | std::vector<llama_token> labels_sparse(n_ctx); |
2331 | |
|
2332 | 0 | int64_t idata = 0; |
2333 | |
|
2334 | 0 | int64_t t_loop_start = ggml_time_us(); |
2335 | 0 | int64_t ndata_in_loop = idata_split*ubatch_per_ctx; |
2336 | 0 | for (; idata < idata_split; ++idata) { |
2337 | 0 | constexpr bool train = true; |
2338 | 0 | const int64_t idata_in_loop = idata*ubatch_per_ctx; |
2339 | |
|
2340 | 0 | ggml_opt_dataset_get_batch_host(dataset, tokens.data(), n_ctx*sizeof(llama_token), labels_sparse.data(), idata); |
2341 | 0 | opt_epoch_iter(dataset, result_train, tokens, labels_sparse, batch, |
2342 | 0 | callback_train, train, idata_in_loop, ndata_in_loop, t_loop_start); |
2343 | 0 | } |
2344 | |
|
2345 | 0 | t_loop_start = ggml_time_us(); |
2346 | 0 | ndata_in_loop = (ndata - idata_split)*ubatch_per_ctx; |
2347 | 0 | for (; idata < ndata; ++idata) { |
2348 | 0 | constexpr bool train = false; |
2349 | 0 | const int64_t idata_in_loop = (idata - idata_split)*ubatch_per_ctx; |
2350 | |
|
2351 | 0 | ggml_opt_dataset_get_batch_host(dataset, tokens.data(), n_ctx*sizeof(llama_token), labels_sparse.data(), idata); |
2352 | 0 | opt_epoch_iter(dataset, result_eval, tokens, labels_sparse, batch, |
2353 | 0 | callback_eval, train, idata_in_loop, ndata_in_loop, t_loop_start); |
2354 | 0 | } |
2355 | |
|
2356 | 0 | llama_batch_free(batch); |
2357 | 0 | } |
2358 | | |
2359 | | // |
2360 | | // interface implementation |
2361 | | // |
2362 | | |
2363 | 0 | llama_context_params llama_context_default_params() { |
2364 | 0 | llama_context_params result = { |
2365 | 0 | /*.n_ctx =*/ 512, |
2366 | 0 | /*.n_batch =*/ 2048, |
2367 | 0 | /*.n_ubatch =*/ 512, |
2368 | 0 | /*.n_seq_max =*/ 1, |
2369 | 0 | /*.n_threads =*/ GGML_DEFAULT_N_THREADS, // TODO: better default |
2370 | 0 | /*.n_threads_batch =*/ GGML_DEFAULT_N_THREADS, |
2371 | 0 | /*.rope_scaling_type =*/ LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED, |
2372 | 0 | /*.pooling_type =*/ LLAMA_POOLING_TYPE_UNSPECIFIED, |
2373 | 0 | /*.attention_type =*/ LLAMA_ATTENTION_TYPE_UNSPECIFIED, |
2374 | 0 | /*.flash_attn_type =*/ LLAMA_FLASH_ATTN_TYPE_AUTO, |
2375 | 0 | /*.rope_freq_base =*/ 0.0f, |
2376 | 0 | /*.rope_freq_scale =*/ 0.0f, |
2377 | 0 | /*.yarn_ext_factor =*/ -1.0f, |
2378 | 0 | /*.yarn_attn_factor =*/ -1.0f, |
2379 | 0 | /*.yarn_beta_fast =*/ -1.0f, |
2380 | 0 | /*.yarn_beta_slow =*/ -1.0f, |
2381 | 0 | /*.yarn_orig_ctx =*/ 0, |
2382 | 0 | /*.defrag_thold =*/ -1.0f, |
2383 | 0 | /*.cb_eval =*/ nullptr, |
2384 | 0 | /*.cb_eval_user_data =*/ nullptr, |
2385 | 0 | /*.type_k =*/ GGML_TYPE_F16, |
2386 | 0 | /*.type_v =*/ GGML_TYPE_F16, |
2387 | 0 | /*.abort_callback =*/ nullptr, |
2388 | 0 | /*.abort_callback_data =*/ nullptr, |
2389 | 0 | /*.embeddings =*/ false, |
2390 | 0 | /*.offload_kqv =*/ true, |
2391 | 0 | /*.no_perf =*/ true, |
2392 | 0 | /*.op_offload =*/ true, |
2393 | 0 | /*.swa_full =*/ true, |
2394 | 0 | /*.kv_unified =*/ false, |
2395 | 0 | }; |
2396 | |
|
2397 | 0 | return result; |
2398 | 0 | } |
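A minimal sketch of how these defaults are typically consumed together with llama_init_from_model below (the helper name make_context and the chosen values are made up; model is assumed to have been loaded elsewhere):

#include "llama.h"

static llama_context * make_context(llama_model * model) {
    llama_context_params cparams = llama_context_default_params();

    cparams.n_ctx     = 4096; // 0 means: use the model's training context size
    cparams.n_threads = 8;

    // returns nullptr on invalid parameters or allocation failure (validated in llama_init_from_model below)
    return llama_init_from_model(model, cparams);
}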
2399 | | |
2400 | | llama_context * llama_init_from_model( |
2401 | | llama_model * model, |
2402 | 0 | llama_context_params params) { |
2403 | 0 | if (!model) { |
2404 | 0 | LLAMA_LOG_ERROR("%s: model cannot be NULL\n", __func__); |
2405 | 0 | return nullptr; |
2406 | 0 | } |
2407 | | |
2408 | 0 | if (params.n_batch == 0 && params.n_ubatch == 0) { |
2409 | 0 | LLAMA_LOG_ERROR("%s: n_batch and n_ubatch cannot both be zero\n", __func__); |
2410 | 0 | return nullptr; |
2411 | 0 | } |
2412 | | |
2413 | 0 | if (params.n_ctx == 0 && model->hparams.n_ctx_train == 0) { |
2414 | 0 | LLAMA_LOG_ERROR("%s: n_ctx and model->hparams.n_ctx_train cannot both be zero\n", __func__); |
2415 | 0 | return nullptr; |
2416 | 0 | } |
2417 | | |
2418 | 0 | if (params.flash_attn_type != LLAMA_FLASH_ATTN_TYPE_DISABLED && model->arch == LLM_ARCH_GROK) { |
2419 | 0 | LLAMA_LOG_WARN("%s: flash_attn is not compatible with Grok - forcing off\n", __func__); |
2420 | 0 | params.flash_attn_type = LLAMA_FLASH_ATTN_TYPE_DISABLED; |
2421 | 0 | } |
2422 | |
|
2423 | 0 | if (params.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_AUTO && ggml_is_quantized(params.type_k)) { |
2424 | 0 | const uint32_t blck_size = ggml_blck_size(params.type_k); |
2425 | 0 | if (model->hparams.n_embd_head_k % blck_size != 0) { |
2426 | 0 | LLAMA_LOG_ERROR("%s: K cache type %s with block size %u does not divide n_embd_head_k=%u\n", |
2427 | 0 | __func__, ggml_type_name(params.type_k), blck_size, model->hparams.n_embd_head_k); |
2428 | 0 | return nullptr; |
2429 | 0 | } |
2430 | 0 | } |
2431 | | |
2432 | 0 | if (params.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_AUTO && ggml_is_quantized(params.type_v)) { |
2433 | 0 | const uint32_t blck_size = ggml_blck_size(params.type_v); |
2434 | 0 | if (model->hparams.n_embd_head_v % blck_size != 0) { |
2435 | 0 | LLAMA_LOG_ERROR("%s: V cache type %s with block size %u does not divide n_embd_head_v=%u\n", |
2436 | 0 | __func__, ggml_type_name(params.type_v), blck_size, model->hparams.n_embd_head_v); |
2437 | 0 | return nullptr; |
2438 | 0 | } |
2439 | 0 | } |
2440 | | |
2441 | 0 | if (ggml_is_quantized(params.type_v) && params.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_DISABLED) { |
2442 | 0 | LLAMA_LOG_ERROR("%s: V cache quantization requires flash_attn\n", __func__); |
2443 | 0 | return nullptr; |
2444 | 0 | } |
2445 | | |
2446 | 0 | if (params.pooling_type != LLAMA_POOLING_TYPE_UNSPECIFIED && |
2447 | 0 | params.pooling_type != model->hparams.pooling_type) { |
2448 | | // user-specified pooling type is different from the model default |
2449 | 0 | LLAMA_LOG_WARN("%s: model default pooling_type is [%d], but [%d] was specified\n", __func__, |
2450 | 0 | model->hparams.pooling_type, params.pooling_type); |
2451 | 0 | } |
2452 | |
|
2453 | 0 | try { |
2454 | 0 | auto * ctx = new llama_context(*model, params); |
2455 | 0 | return ctx; |
2456 | 0 | } catch (const std::exception & err) { |
2457 | 0 | LLAMA_LOG_ERROR("%s: failed to initialize the context: %s\n", __func__, err.what()); |
2458 | 0 | } |
2459 | | |
2460 | 0 | return nullptr; |
2461 | 0 | } |
2462 | | |
2463 | | // deprecated |
2464 | | llama_context * llama_new_context_with_model( |
2465 | | llama_model * model, |
2466 | 0 | llama_context_params params) { |
2467 | 0 | return llama_init_from_model(model, params); |
2468 | 0 | } |
2469 | | |
2470 | 0 | void llama_free(llama_context * ctx) { |
2471 | 0 | delete ctx; |
2472 | 0 | } |
2473 | | |
2474 | 0 | uint32_t llama_n_ctx(const llama_context * ctx) { |
2475 | 0 | return ctx->n_ctx(); |
2476 | 0 | } |
2477 | | |
2478 | 0 | uint32_t llama_n_ctx_seq(const llama_context * ctx) { |
2479 | 0 | return ctx->n_ctx_seq(); |
2480 | 0 | } |
2481 | | |
2482 | 0 | uint32_t llama_n_batch(const llama_context * ctx) { |
2483 | 0 | return ctx->n_batch(); |
2484 | 0 | } |
2485 | | |
2486 | 0 | uint32_t llama_n_ubatch(const llama_context * ctx) { |
2487 | 0 | return ctx->n_ubatch(); |
2488 | 0 | } |
2489 | | |
2490 | 0 | uint32_t llama_n_seq_max(const llama_context * ctx) { |
2491 | 0 | return ctx->n_seq_max(); |
2492 | 0 | } |
2493 | | |
2494 | 0 | const llama_model * llama_get_model(const llama_context * ctx) { |
2495 | 0 | return &ctx->get_model(); |
2496 | 0 | } |
2497 | | |
2498 | 0 | enum llama_pooling_type llama_pooling_type(const llama_context * ctx) { |
2499 | 0 | return ctx->pooling_type(); |
2500 | 0 | } |
2501 | | |
2502 | | void llama_attach_threadpool( |
2503 | | llama_context * ctx, |
2504 | | ggml_threadpool_t threadpool, |
2505 | 0 | ggml_threadpool_t threadpool_batch) { |
2506 | 0 | ctx->attach_threadpool(threadpool, threadpool_batch); |
2507 | 0 | } |
2508 | | |
2509 | 0 | void llama_detach_threadpool(llama_context * ctx) { |
2510 | 0 | ctx->detach_threadpool(); |
2511 | 0 | } |
2512 | | |
2513 | 0 | void llama_set_n_threads(llama_context * ctx, int32_t n_threads, int32_t n_threads_batch) { |
2514 | 0 | ctx->set_n_threads(n_threads, n_threads_batch); |
2515 | 0 | } |
2516 | | |
2517 | 0 | int32_t llama_n_threads(llama_context * ctx) { |
2518 | 0 | return ctx->n_threads(); |
2519 | 0 | } |
2520 | | |
2521 | 0 | int32_t llama_n_threads_batch(llama_context * ctx) { |
2522 | 0 | return ctx->n_threads_batch(); |
2523 | 0 | } |
2524 | | |
2525 | 0 | void llama_set_abort_callback(llama_context * ctx, bool (*abort_callback)(void * data), void * abort_callback_data) { |
2526 | 0 | ctx->set_abort_callback(abort_callback, abort_callback_data); |
2527 | 0 | } |
2528 | | |
2529 | 0 | void llama_set_embeddings(llama_context * ctx, bool embeddings) { |
2530 | 0 | ctx->set_embeddings(embeddings); |
2531 | 0 | } |
2532 | | |
2533 | 0 | void llama_set_causal_attn(llama_context * ctx, bool causal_attn) { |
2534 | 0 | ctx->set_causal_attn(causal_attn); |
2535 | 0 | } |
2536 | | |
2537 | 0 | void llama_set_warmup(llama_context * ctx, bool warmup) { |
2538 | 0 | ctx->set_warmup(warmup); |
2539 | 0 | } |
2540 | | |
2541 | 0 | void llama_synchronize(llama_context * ctx) { |
2542 | 0 | ctx->synchronize(); |
2543 | 0 | } |
2544 | | |
2545 | 0 | float * llama_get_logits(llama_context * ctx) { |
2546 | 0 | ctx->synchronize(); |
2547 | |
|
2548 | 0 | return ctx->get_logits(); |
2549 | 0 | } |
2550 | | |
2551 | 0 | float * llama_get_logits_ith(llama_context * ctx, int32_t i) { |
2552 | 0 | ctx->synchronize(); |
2553 | |
|
2554 | 0 | return ctx->get_logits_ith(i); |
2555 | 0 | } |
2556 | | |
2557 | 0 | float * llama_get_embeddings(llama_context * ctx) { |
2558 | 0 | ctx->synchronize(); |
2559 | |
|
2560 | 0 | return ctx->get_embeddings(); |
2561 | 0 | } |
2562 | | |
2563 | 0 | float * llama_get_embeddings_ith(llama_context * ctx, int32_t i) { |
2564 | 0 | ctx->synchronize(); |
2565 | |
|
2566 | 0 | return ctx->get_embeddings_ith(i); |
2567 | 0 | } |
2568 | | |
2569 | 0 | float * llama_get_embeddings_seq(llama_context * ctx, llama_seq_id seq_id) { |
2570 | 0 | ctx->synchronize(); |
2571 | |
|
2572 | 0 | return ctx->get_embeddings_seq(seq_id); |
2573 | 0 | } |
2574 | | |
2575 | | // llama adapter API |
2576 | | |
2577 | | int32_t llama_set_adapter_lora( |
2578 | | llama_context * ctx, |
2579 | | llama_adapter_lora * adapter, |
2580 | 0 | float scale) { |
2581 | 0 | ctx->set_adapter_lora(adapter, scale); |
2582 | |
|
2583 | 0 | return 0; |
2584 | 0 | } |
2585 | | |
2586 | | int32_t llama_rm_adapter_lora( |
2587 | | llama_context * ctx, |
2588 | 0 | llama_adapter_lora * adapter) { |
2589 | 0 | bool res = ctx->rm_adapter_lora(adapter); |
2590 | |
|
2591 | 0 | return res ? 0 : -1; |
2592 | 0 | } |
2593 | | |
2594 | 0 | void llama_clear_adapter_lora(llama_context * ctx) { |
2595 | 0 | ctx->clear_adapter_lora(); |
2596 | 0 | } |
2597 | | |
2598 | | int32_t llama_apply_adapter_cvec( |
2599 | | llama_context * ctx, |
2600 | | const float * data, |
2601 | | size_t len, |
2602 | | int32_t n_embd, |
2603 | | int32_t il_start, |
2604 | 0 | int32_t il_end) { |
2605 | 0 | bool res = ctx->apply_adapter_cvec(data, len, n_embd, il_start, il_end); |
2606 | |
|
2607 | 0 | return res ? 0 : -1; |
2608 | 0 | } |
2609 | | |
2610 | | // |
2611 | | // memory |
2612 | | // |
2613 | | |
2614 | 0 | llama_memory_t llama_get_memory(const struct llama_context * ctx) { |
2615 | 0 | return ctx->get_memory(); |
2616 | 0 | } |
2617 | | |
2618 | 0 | void llama_memory_clear(llama_memory_t mem, bool data) { |
2619 | 0 | if (!mem) { |
2620 | 0 | return; |
2621 | 0 | } |
2622 | | |
2623 | 0 | mem->clear(data); |
2624 | 0 | } |
2625 | | |
2626 | | bool llama_memory_seq_rm( |
2627 | | llama_memory_t mem, |
2628 | | llama_seq_id seq_id, |
2629 | | llama_pos p0, |
2630 | 0 | llama_pos p1) { |
2631 | 0 | if (!mem) { |
2632 | 0 | return true; |
2633 | 0 | } |
2634 | | |
2635 | 0 | return mem->seq_rm(seq_id, p0, p1); |
2636 | 0 | } |
2637 | | |
2638 | | void llama_memory_seq_cp( |
2639 | | llama_memory_t mem, |
2640 | | llama_seq_id seq_id_src, |
2641 | | llama_seq_id seq_id_dst, |
2642 | | llama_pos p0, |
2643 | 0 | llama_pos p1) { |
2644 | 0 | if (!mem) { |
2645 | 0 | return; |
2646 | 0 | } |
2647 | | |
2648 | 0 | mem->seq_cp(seq_id_src, seq_id_dst, p0, p1); |
2649 | 0 | } |
2650 | | |
2651 | | void llama_memory_seq_keep( |
2652 | | llama_memory_t mem, |
2653 | 0 | llama_seq_id seq_id) { |
2654 | 0 | if (!mem) { |
2655 | 0 | return; |
2656 | 0 | } |
2657 | | |
2658 | 0 | mem->seq_keep(seq_id); |
2659 | 0 | } |
2660 | | |
2661 | | void llama_memory_seq_add( |
2662 | | llama_memory_t mem, |
2663 | | llama_seq_id seq_id, |
2664 | | llama_pos p0, |
2665 | | llama_pos p1, |
2666 | 0 | llama_pos delta) { |
2667 | 0 | if (!mem) { |
2668 | 0 | return; |
2669 | 0 | } |
2670 | | |
2671 | 0 | mem->seq_add(seq_id, p0, p1, delta); |
2672 | 0 | } |
2673 | | |
2674 | | void llama_memory_seq_div( |
2675 | | llama_memory_t mem, |
2676 | | llama_seq_id seq_id, |
2677 | | llama_pos p0, |
2678 | | llama_pos p1, |
2679 | 0 | int d) { |
2680 | 0 | if (!mem) { |
2681 | 0 | return; |
2682 | 0 | } |
2683 | | |
2684 | 0 | mem->seq_div(seq_id, p0, p1, d); |
2685 | 0 | } |
2686 | | |
2687 | | llama_pos llama_memory_seq_pos_min( |
2688 | | llama_memory_t mem, |
2689 | 0 | llama_seq_id seq_id) { |
2690 | 0 | if (!mem) { |
2691 | 0 | return -1; |
2692 | 0 | } |
2693 | | |
2694 | 0 | return mem->seq_pos_min(seq_id); |
2695 | 0 | } |
2696 | | |
2697 | | llama_pos llama_memory_seq_pos_max( |
2698 | | llama_memory_t mem, |
2699 | 0 | llama_seq_id seq_id) { |
2700 | 0 | if (!mem) { |
2701 | 0 | return -1; |
2702 | 0 | } |
2703 | | |
2704 | 0 | return mem->seq_pos_max(seq_id); |
2705 | 0 | } |
2706 | | |
2707 | 0 | bool llama_memory_can_shift(llama_memory_t mem) { |
2708 | 0 | if (!mem) { |
2709 | 0 | return false; |
2710 | 0 | } |
2711 | | |
2712 | 0 | return mem->get_can_shift(); |
2713 | 0 | } |
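A short usage sketch for the memory wrappers above (the helper name truncate_seq0 is made up; it assumes the usual llama.h convention that a negative p1 means 'up to the end of the sequence'):

#include "llama.h"

// drops the cached tokens of sequence 0 from position p0 onward
static void truncate_seq0(llama_context * ctx, llama_pos p0) {
    llama_memory_t mem = llama_get_memory(ctx);

    // the wrappers above treat a null memory as a no-op, so this check is optional
    if (mem != nullptr) {
        llama_memory_seq_rm(mem, /*seq_id =*/ 0, /*p0 =*/ p0, /*p1 =*/ -1);
    }
}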
2714 | | |
2715 | | // llama state API |
2716 | | |
2717 | | // deprecated |
2718 | 0 | size_t llama_get_state_size(llama_context * ctx) { |
2719 | 0 | return llama_state_get_size(ctx); |
2720 | 0 | } |
2721 | | |
2722 | | // deprecated |
2723 | 0 | size_t llama_copy_state_data(llama_context * ctx, uint8_t * dst) { |
2724 | 0 | return llama_state_get_data(ctx, dst, -1); |
2725 | 0 | } |
2726 | | |
2727 | | // deprecated |
2728 | 0 | size_t llama_set_state_data(llama_context * ctx, const uint8_t * src) { |
2729 | 0 | return llama_state_set_data(ctx, src, -1); |
2730 | 0 | } |
2731 | | |
2732 | | // deprecated |
2733 | 0 | bool llama_load_session_file(llama_context * ctx, const char * path_session, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) { |
2734 | 0 | return llama_state_load_file(ctx, path_session, tokens_out, n_token_capacity, n_token_count_out); |
2735 | 0 | } |
2736 | | |
2737 | | // deprecated |
2738 | 0 | bool llama_save_session_file(llama_context * ctx, const char * path_session, const llama_token * tokens, size_t n_token_count) { |
2739 | 0 | return llama_state_save_file(ctx, path_session, tokens, n_token_count); |
2740 | 0 | } |
2741 | | |
2742 | | // Returns the *actual* size of the state. |
2743 | | // Intended to be used when saving the state to a buffer. |
2744 | 0 | size_t llama_state_get_size(llama_context * ctx) { |
2745 | 0 | return ctx->state_get_size(); |
2746 | 0 | } |
2747 | | |
2748 | 0 | size_t llama_state_get_data(llama_context * ctx, uint8_t * dst, size_t size) { |
2749 | 0 | ctx->synchronize(); |
2750 | |
|
2751 | 0 | return ctx->state_get_data(dst, size); |
2752 | 0 | } |
2753 | | |
2754 | | // Sets the state reading from the specified source address |
2755 | 0 | size_t llama_state_set_data(llama_context * ctx, const uint8_t * src, size_t size) { |
2756 | 0 | ctx->synchronize(); |
2757 | |
|
2758 | 0 | return ctx->state_set_data(src, size); |
2759 | 0 | } |
2760 | | |
2761 | 0 | bool llama_state_load_file(llama_context * ctx, const char * path_session, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) { |
2762 | 0 | ctx->synchronize(); |
2763 | |
|
2764 | 0 | try { |
2765 | 0 | return ctx->state_load_file(path_session, tokens_out, n_token_capacity, n_token_count_out); |
2766 | 0 | } catch (const std::exception & err) { |
2767 | 0 | LLAMA_LOG_ERROR("%s: error loading session file: %s\n", __func__, err.what()); |
2768 | 0 | return false; |
2769 | 0 | } |
2770 | 0 | } |
2771 | | |
2772 | 0 | bool llama_state_save_file(llama_context * ctx, const char * path_session, const llama_token * tokens, size_t n_token_count) { |
2773 | 0 | ctx->synchronize(); |
2774 | |
|
2775 | 0 | try { |
2776 | 0 | return ctx->state_save_file(path_session, tokens, n_token_count); |
2777 | 0 | } catch (const std::exception & err) { |
2778 | 0 | LLAMA_LOG_ERROR("%s: error saving session file: %s\n", __func__, err.what()); |
2779 | 0 | return false; |
2780 | 0 | } |
2781 | 0 | } |
2782 | | |
2783 | 0 | size_t llama_state_seq_get_size(llama_context * ctx, llama_seq_id seq_id) { |
2784 | 0 | return llama_state_seq_get_size_ext(ctx, seq_id, 0); |
2785 | 0 | } |
2786 | | |
2787 | 0 | size_t llama_state_seq_get_data(llama_context * ctx, uint8_t * dst, size_t size, llama_seq_id seq_id) { |
2788 | 0 | return llama_state_seq_get_data_ext(ctx, dst, size, seq_id, 0); |
2789 | 0 | } |
2790 | | |
2791 | 0 | size_t llama_state_seq_set_data(llama_context * ctx, const uint8_t * src, size_t size, llama_seq_id seq_id) { |
2792 | 0 | return llama_state_seq_set_data_ext(ctx, src, size, seq_id, 0); |
2793 | 0 | } |
2794 | | |
2795 | 0 | size_t llama_state_seq_get_size_ext(llama_context * ctx, llama_seq_id seq_id, llama_state_seq_flags flags) { |
2796 | 0 | return ctx->state_seq_get_size(seq_id, flags); |
2797 | 0 | } |
2798 | | |
2799 | 0 | size_t llama_state_seq_get_data_ext(llama_context * ctx, uint8_t * dst, size_t size, llama_seq_id seq_id, llama_state_seq_flags flags) { |
2800 | 0 | ctx->synchronize(); |
2801 | |
|
2802 | 0 | return ctx->state_seq_get_data(seq_id, dst, size, flags); |
2803 | 0 | } |
2804 | | |
2805 | 0 | size_t llama_state_seq_set_data_ext(llama_context * ctx, const uint8_t * src, size_t size, llama_seq_id seq_id, llama_state_seq_flags flags) { |
2806 | 0 | ctx->synchronize(); |
2807 | |
|
2808 | 0 | return ctx->state_seq_set_data(seq_id, src, size, flags); |
2809 | 0 | } |
2810 | | |
2811 | 0 | size_t llama_state_seq_save_file(llama_context * ctx, const char * filepath, llama_seq_id seq_id, const llama_token * tokens, size_t n_token_count) { |
2812 | 0 | ctx->synchronize(); |
2813 | |
|
2814 | 0 | try { |
2815 | 0 | return ctx->state_seq_save_file(seq_id, filepath, tokens, n_token_count); |
2816 | 0 | } catch (const std::exception & err) { |
2817 | 0 | LLAMA_LOG_ERROR("%s: error saving sequence state file: %s\n", __func__, err.what()); |
2818 | 0 | return 0; |
2819 | 0 | } |
2820 | 0 | } |
2821 | | |
2822 | 0 | size_t llama_state_seq_load_file(llama_context * ctx, const char * filepath, llama_seq_id dest_seq_id, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) { |
2823 | 0 | ctx->synchronize(); |
2824 | |
2825 | 0 | try { |
2826 | 0 | return ctx->state_seq_load_file(dest_seq_id, filepath, tokens_out, n_token_capacity, n_token_count_out); |
2827 | 0 | } catch (const std::exception & err) { |
2828 | 0 | LLAMA_LOG_ERROR("%s: error loading sequence state file: %s\n", __func__, err.what()); |
2829 | 0 | return 0; |
2830 | 0 | } |
2831 | 0 | } |
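Unlike the boolean whole-context file helpers, the per-sequence file helpers report failure by returning 0; on success they return the number of bytes written or read. A sketch with placeholder file name and sequence ids:

    #include "llama.h"
    #include <cstdio>
    #include <vector>

    // sketch: save one sequence to disk and restore it into a (possibly different) sequence id
    static void seq_file_roundtrip(llama_context * ctx, std::vector<llama_token> & toks) {
        if (llama_state_seq_save_file(ctx, "seq0.bin", /*seq_id=*/0, toks.data(), toks.size()) == 0) {
            fprintf(stderr, "failed to save sequence state\n");
            return;
        }
        size_t n_restored = 0;
        if (llama_state_seq_load_file(ctx, "seq0.bin", /*dest_seq_id=*/2, toks.data(), toks.size(), &n_restored) == 0) {
            fprintf(stderr, "failed to load sequence state\n");
        }
    }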
2832 | | |
2833 | | /// |
2834 | | |
2835 | | int32_t llama_encode( |
2836 | | llama_context * ctx, |
2837 | 0 | llama_batch batch) { |
2838 | 0 | const int ret = ctx->encode(batch); |
2839 | 0 | if (ret != 0) { |
2840 | 0 | LLAMA_LOG_ERROR("%s: failed to encode, ret = %d\n", __func__, ret); |
2841 | 0 | } |
2842 | |
2843 | 0 | return ret; |
2844 | 0 | } |
2845 | | |
2846 | | int32_t llama_decode( |
2847 | | llama_context * ctx, |
2848 | 0 | llama_batch batch) { |
2849 | 0 | const int ret = ctx->decode(batch); |
2850 | 0 | if (ret != 0 && ret != 1) { |
2851 | 0 | LLAMA_LOG_ERROR("%s: failed to decode, ret = %d\n", __func__, ret); |
2852 | 0 | } |
2853 | |
2854 | 0 | return ret; |
2855 | 0 | } |
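Note the asymmetry in error reporting: llama_encode() logs any non-zero return, while llama_decode() treats 1 as a recoverable "could not find a KV slot for the batch" condition and only logs other non-zero values. A minimal decode step, assuming the public llama.h batch helper (its exact signature has changed between releases):

    #include "llama.h"

    // sketch: decode a single token and interpret the documented return codes
    static bool decode_one(llama_context * ctx, llama_token tok) {
        llama_batch batch = llama_batch_get_one(&tok, 1);
        const int32_t ret = llama_decode(ctx, batch);
        if (ret == 1) {
            // recoverable: no KV cache slot available - caller may free space or shrink the batch and retry
            return false;
        }
        return ret == 0;
    }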
2856 | | |
2857 | | // |
2858 | | // perf |
2859 | | // |
2860 | | |
2861 | 0 | llama_perf_context_data llama_perf_context(const llama_context * ctx) { |
2862 | 0 | llama_perf_context_data data = {}; |
2863 | |
2864 | 0 | if (ctx == nullptr) { |
2865 | 0 | return data; |
2866 | 0 | } |
2867 | | |
2868 | 0 | data = ctx->perf_get_data(); |
2869 | |
2870 | 0 | return data; |
2871 | 0 | } |
2872 | | |
2873 | 0 | void llama_perf_context_print(const llama_context * ctx) { |
2874 | 0 | const auto data = llama_perf_context(ctx); |
2875 | |
2876 | 0 | const double t_end_ms = 1e-3 * ggml_time_us(); |
2877 | |
2878 | 0 | LLAMA_LOG_INFO("%s: load time = %10.2f ms\n", __func__, data.t_load_ms); |
2879 | 0 | LLAMA_LOG_INFO("%s: prompt eval time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)\n", |
2880 | 0 | __func__, data.t_p_eval_ms, data.n_p_eval, data.t_p_eval_ms / data.n_p_eval, 1e3 / data.t_p_eval_ms * data.n_p_eval); |
2881 | 0 | LLAMA_LOG_INFO("%s: eval time = %10.2f ms / %5d runs (%8.2f ms per token, %8.2f tokens per second)\n", |
2882 | 0 | __func__, data.t_eval_ms, data.n_eval, data.t_eval_ms / data.n_eval, 1e3 / data.t_eval_ms * data.n_eval); |
2883 | 0 | LLAMA_LOG_INFO("%s: total time = %10.2f ms / %5d tokens\n", __func__, (t_end_ms - data.t_start_ms), (data.n_p_eval + data.n_eval)); |
2884 | 0 | LLAMA_LOG_INFO("%s: graphs reused = %10d\n", __func__, data.n_reused); |
2885 | 0 | } |
2886 | | |
2887 | 0 | void llama_perf_context_reset(llama_context * ctx) { |
2888 | 0 | ctx->perf_reset(); |
2889 | 0 | } |
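The perf counters accumulate across decode/encode calls, so benchmark loops typically print and then reset them between repetitions; llama_perf_context() can be used directly when the raw numbers are needed. A sketch:

    #include "llama.h"
    #include <cstdio>

    // sketch: report prompt throughput programmatically, then clear the counters
    static void report_and_reset(llama_context * ctx) {
        const llama_perf_context_data pd = llama_perf_context(ctx);
        if (pd.t_p_eval_ms > 0.0) {
            printf("prompt: %.2f tokens per second\n", 1e3 / pd.t_p_eval_ms * pd.n_p_eval);
        }
        llama_perf_context_print(ctx); // same numbers, formatted as above
        llama_perf_context_reset(ctx);
    }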
2890 | | |
2891 | 0 | void llama_memory_breakdown_print(const struct llama_context * ctx) { |
2892 | 0 | const std::vector<ggml_backend_dev_t> & devices = ctx->get_model().devices; |
2893 | |
2894 | 0 | std::map<ggml_backend_buffer_type_t, llama_memory_breakdown_data> memory_breakdown = ctx->memory_breakdown(); |
2895 | |
2896 | 0 | std::vector<std::array<std::string, 9>> table_data; |
2897 | 0 | table_data.reserve(devices.size()); |
2898 | 0 | const std::string template_header = "%s: | %s | %s %s %s %s %s %s %s |\n"; |
2899 | 0 | const std::string template_gpu = "%s: | %s | %s = %s + (%s = %s + %s + %s) + %s |\n"; |
2900 | 0 | const std::string template_other = "%s: | %s | %s %s %s = %s + %s + %s %s |\n"; |
2901 | |
2902 | 0 | table_data.push_back({template_header, "memory breakdown [MiB]", "total", "free", "self", "model", "context", "compute", "unaccounted"}); |
2903 | |
2904 | 0 | constexpr size_t MiB = 1024 * 1024; |
2905 | 0 | const std::vector<std::string> desc_prefixes_strip = {"NVIDIA ", "GeForce ", "Tesla ", "AMD ", "Radeon ", "Instinct "}; |
2906 | | |
2907 | | // track seen buffer types to avoid double counting: |
2908 | 0 | std::set<ggml_backend_buffer_type_t> seen_buffer_types; |
2909 | | |
2910 | | // cumulative memory breakdown for each device and for host:
2911 | 0 | std::vector<llama_memory_breakdown_data> mb_dev(devices.size()); |
2912 | 0 | llama_memory_breakdown_data mb_host; |
2913 | |
2914 | 0 | for (const auto & buft_mb : memory_breakdown) { |
2915 | 0 | ggml_backend_buffer_type_t buft = buft_mb.first; |
2916 | 0 | const llama_memory_breakdown_data & mb = buft_mb.second; |
2917 | 0 | if (ggml_backend_buft_is_host(buft)) { |
2918 | 0 | mb_host.model += mb.model; |
2919 | 0 | mb_host.context += mb.context; |
2920 | 0 | mb_host.compute += mb.compute; |
2921 | 0 | seen_buffer_types.insert(buft); |
2922 | 0 | continue; |
2923 | 0 | } |
2924 | 0 | ggml_backend_dev_t dev = ggml_backend_buft_get_device(buft); |
2925 | 0 | if (dev) { |
2926 | 0 | int i_dev = -1; |
2927 | 0 | for (size_t i = 0; i < devices.size(); i++) { |
2928 | 0 | if (devices[i] == dev) { |
2929 | 0 | i_dev = i; |
2930 | 0 | break; |
2931 | 0 | } |
2932 | 0 | } |
2933 | 0 | if (i_dev != -1) { |
2934 | 0 | mb_dev[i_dev].model += mb.model; |
2935 | 0 | mb_dev[i_dev].context += mb.context; |
2936 | 0 | mb_dev[i_dev].compute += mb.compute; |
2937 | 0 | seen_buffer_types.insert(buft); |
2938 | 0 | continue; |
2939 | 0 | } |
2940 | 0 | } |
2941 | 0 | } |
2942 | | |
2943 | | // print memory breakdown for each device: |
2944 | 0 | for (size_t i = 0; i < devices.size(); i++) { |
2945 | 0 | ggml_backend_dev_t dev = devices[i]; |
2946 | 0 | llama_memory_breakdown_data mb = mb_dev[i]; |
2947 | |
2948 | 0 | const std::string name = ggml_backend_dev_name(dev); |
2949 | 0 | std::string desc = ggml_backend_dev_description(dev); |
2950 | 0 | for (const std::string & prefix : desc_prefixes_strip) { |
2951 | 0 | if (desc.length() >= prefix.length() && desc.substr(0, prefix.length()) == prefix) { |
2952 | 0 | desc = desc.substr(prefix.length()); |
2953 | 0 | } |
2954 | 0 | } |
2955 | |
2956 | 0 | size_t free, total; |
2957 | 0 | ggml_backend_dev_memory(dev, &free, &total); |
2958 | |
2959 | 0 | const size_t self = mb.model + mb.context + mb.compute; |
2960 | 0 | const size_t unaccounted = total - self - free; |
2961 | |
2962 | 0 | table_data.push_back({ |
2963 | 0 | template_gpu, |
2964 | 0 | " - " + name + " (" + desc + ")", |
2965 | 0 | std::to_string(total / MiB), |
2966 | 0 | std::to_string(free / MiB), |
2967 | 0 | std::to_string(self / MiB), |
2968 | 0 | std::to_string(mb.model / MiB), |
2969 | 0 | std::to_string(mb.context / MiB), |
2970 | 0 | std::to_string(mb.compute / MiB), |
2971 | 0 | std::to_string(unaccounted / MiB)}); |
2972 | 0 | } |
2973 | | |
2974 | | // print memory breakdown for host: |
2975 | 0 | { |
2976 | 0 | const size_t self = mb_host.model + mb_host.context + mb_host.compute; |
2977 | 0 | table_data.push_back({ |
2978 | 0 | template_other, |
2979 | 0 | " - Host", |
2980 | 0 | "", // total |
2981 | 0 | "", // free |
2982 | 0 | std::to_string(self / MiB), |
2983 | 0 | std::to_string(mb_host.model / MiB), |
2984 | 0 | std::to_string(mb_host.context / MiB), |
2985 | 0 | std::to_string(mb_host.compute / MiB), |
2986 | 0 | ""}); // unaccounted |
2987 | 0 | } |
2988 | | |
2989 | | // print memory breakdown for all remaining buffer types: |
2990 | 0 | for (const auto & buft_mb : memory_breakdown) { |
2991 | 0 | ggml_backend_buffer_type_t buft = buft_mb.first; |
2992 | 0 | const llama_memory_breakdown_data & mb = buft_mb.second; |
2993 | 0 | if (seen_buffer_types.count(buft) == 1) { |
2994 | 0 | continue; |
2995 | 0 | } |
2996 | 0 | const std::string name = ggml_backend_buft_name(buft); |
2997 | 0 | const size_t self = mb.model + mb.context + mb.compute; |
2998 | 0 | table_data.push_back({ |
2999 | 0 | template_other, |
3000 | 0 | " - " + name, |
3001 | 0 | "", // total |
3002 | 0 | "", // free |
3003 | 0 | std::to_string(self / MiB), |
3004 | 0 | std::to_string(mb.model / MiB), |
3005 | 0 | std::to_string(mb.context / MiB), |
3006 | 0 | std::to_string(mb.compute / MiB), |
3007 | 0 | ""}); // unaccounted |
3008 | 0 | seen_buffer_types.insert(buft); |
3009 | 0 | } |
3010 | |
3011 | 0 | for (size_t j = 1; j < table_data[0].size(); j++) { |
3012 | 0 | size_t max_len = 0; |
3013 | 0 | for (const auto & td : table_data) { |
3014 | 0 | max_len = std::max(max_len, td[j].length()); |
3015 | 0 | } |
3016 | 0 | for (auto & td : table_data) { |
3017 | 0 | td[j].insert(j == 1 ? td[j].length() : 0, max_len - td[j].length(), ' '); |
3018 | 0 | } |
3019 | 0 | } |
3020 | 0 | for (const auto & td : table_data) { |
3021 | 0 | LLAMA_LOG_INFO(td[0].c_str(), |
3022 | 0 | __func__, td[1].c_str(), td[2].c_str(), td[3].c_str(), td[4].c_str(), td[5].c_str(), |
3023 | 0 | td[6].c_str(), td[7].c_str(), td[8].c_str()); |
3024 | 0 | } |
3025 | 0 | } |
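For a GPU device the row template spells out the identity total = free + (self = model + context + compute) + unaccounted; host and other buffer types report only the self term, since total/free are not meaningful there. The arithmetic, isolated as a sketch:

    // sketch: the per-device bookkeeping behind the table above (all values in bytes)
    static size_t unaccounted_bytes(size_t total, size_t free, size_t model, size_t context, size_t compute) {
        const size_t self = model + context + compute;
        return total - free - self; // device memory in use that this context cannot attribute to itself
    }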
3026 | | |
3027 | | // |
3028 | | // training |
3029 | | // |
3030 | | |
3031 | 0 | bool llama_opt_param_filter_all(const struct ggml_tensor * tensor, void * userdata) { |
3032 | 0 | GGML_UNUSED(tensor); |
3033 | 0 | GGML_UNUSED(userdata); |
3034 | 0 | return true; |
3035 | 0 | } |
3036 | | |
3037 | 0 | void llama_opt_init(struct llama_context * ctx, struct llama_model * model, struct llama_opt_params lopt_params) { |
3038 | 0 | ctx->opt_init(model, lopt_params); |
3039 | 0 | } |
3040 | | |
3041 | | void llama_opt_epoch( |
3042 | | struct llama_context * ctx, |
3043 | | ggml_opt_dataset_t dataset, |
3044 | | ggml_opt_result_t result_train, |
3045 | | ggml_opt_result_t result_eval, |
3046 | | int64_t idata_split, |
3047 | | ggml_opt_epoch_callback callback_train, |
3048 | 0 | ggml_opt_epoch_callback callback_eval) { |
3049 | 0 | ctx->opt_epoch( |
3050 | 0 | dataset, |
3051 | 0 | result_train, |
3052 | 0 | result_eval, |
3053 | 0 | idata_split, |
3054 | 0 | callback_train, |
3055 | 0 | callback_eval); |
3056 | 0 | } |
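These training entry points are thin forwards into llama_context: llama_opt_init() prepares the context for optimization and llama_opt_epoch() runs one pass over a ggml-opt dataset, using idata_split to separate the training portion from the evaluation portion. A minimal driver sketch, assuming the dataset and result handles were created through the ggml-opt API elsewhere and that the optional epoch callbacks may be omitted:

    #include "llama.h"

    // sketch: one training epoch followed by evaluation on the held-out tail of the dataset
    static void run_epoch(llama_context * ctx, ggml_opt_dataset_t dataset,
                          ggml_opt_result_t result_train, ggml_opt_result_t result_eval,
                          int64_t idata_split) {
        llama_opt_epoch(ctx, dataset, result_train, result_eval, idata_split,
                        /*callback_train=*/nullptr, /*callback_eval=*/nullptr);
    }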