/src/llama.cpp/src/llama-context.cpp
Line | Count | Source |
1 | | #include "llama-context.h" |
2 | | |
3 | | #include "llama-impl.h" |
4 | | #include "llama-batch.h" |
5 | | #include "llama-io.h" |
6 | | #include "llama-memory.h" |
7 | | #include "llama-mmap.h" |
8 | | #include "llama-model.h" |
9 | | |
10 | | #include <cinttypes> |
11 | | #include <cstring> |
12 | | #include <limits> |
13 | | #include <stdexcept> |
14 | | |
15 | | // |
16 | | // llama_context |
17 | | // |
18 | | |
19 | | llama_context::llama_context( |
20 | | const llama_model & model, |
21 | | llama_context_params params) : |
22 | 0 | model(model), |
23 | 0 | balloc(std::make_unique<llama_batch_allocr>(model.hparams.n_pos_per_embd())) { |
24 | | // TODO warning when creating llama_context with awkward ctx size that is not a power of 2, |
25 | | // may need to be backend-dependent |
26 | 0 | LLAMA_LOG_INFO("%s: constructing llama_context\n", __func__); |
27 | |
|
28 | 0 | t_start_us = model.t_start_us; |
29 | 0 | t_load_us = model.t_load_us; |
30 | |
|
31 | 0 | const auto & hparams = model.hparams; |
32 | |
|
33 | 0 | cparams.n_seq_max = std::max(1u, params.n_seq_max); |
34 | 0 | if (cparams.n_seq_max > LLAMA_MAX_SEQ) { |
35 | 0 | throw std::runtime_error("n_seq_max must be <= " + std::to_string(LLAMA_MAX_SEQ)); |
36 | 0 | } |
37 | | |
38 | 0 | cparams.n_threads = params.n_threads; |
39 | 0 | cparams.n_threads_batch = params.n_threads_batch; |
40 | 0 | cparams.yarn_ext_factor = params.yarn_ext_factor >= 0.0f ? params.yarn_ext_factor : hparams.yarn_ext_factor; |
41 | 0 | cparams.yarn_attn_factor = params.yarn_attn_factor >= 0.0f ? params.yarn_attn_factor : hparams.yarn_attn_factor; |
42 | 0 | cparams.yarn_beta_fast = params.yarn_beta_fast >= 0.0f ? params.yarn_beta_fast : hparams.yarn_beta_fast; |
43 | 0 | cparams.yarn_beta_slow = params.yarn_beta_slow >= 0.0f ? params.yarn_beta_slow : hparams.yarn_beta_slow; |
44 | 0 | cparams.embeddings = params.embeddings; |
45 | 0 | cparams.offload_kqv = params.offload_kqv; |
46 | 0 | cparams.no_perf = params.no_perf; |
47 | 0 | cparams.pooling_type = params.pooling_type; |
48 | 0 | cparams.warmup = false; |
49 | |
|
50 | 0 | cparams.n_ctx = params.n_ctx == 0 ? hparams.n_ctx_train : params.n_ctx; |
51 | 0 | cparams.rope_freq_base = params.rope_freq_base == 0.0f ? hparams.rope_freq_base_train : params.rope_freq_base; |
52 | 0 | cparams.rope_freq_scale = params.rope_freq_scale == 0.0f ? hparams.rope_freq_scale_train : params.rope_freq_scale; |
53 | |
|
54 | 0 | cparams.n_ctx_orig_yarn = params.yarn_orig_ctx != 0 ? params.yarn_orig_ctx : |
55 | 0 | hparams.n_ctx_orig_yarn != 0 ? hparams.n_ctx_orig_yarn : |
56 | 0 | hparams.n_ctx_train; |
57 | |
|
58 | 0 | cparams.cb_eval = params.cb_eval; |
59 | 0 | cparams.cb_eval_user_data = params.cb_eval_user_data; |
60 | |
|
61 | 0 | auto rope_scaling_type = params.rope_scaling_type; |
62 | 0 | if (rope_scaling_type == LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED) { |
63 | 0 | rope_scaling_type = hparams.rope_scaling_type_train; |
64 | 0 | } |
65 | |
|
66 | 0 | if (rope_scaling_type == LLAMA_ROPE_SCALING_TYPE_NONE) { |
67 | 0 | cparams.rope_freq_scale = 1.0f; // never scale if scaling type is none |
68 | 0 | } |
69 | |
|
70 | 0 | if (cparams.yarn_ext_factor < 0.0f) { // negative indicates 'not set' |
71 | 0 | cparams.yarn_ext_factor = rope_scaling_type == LLAMA_ROPE_SCALING_TYPE_YARN ? 1.0f : 0.0f; |
72 | 0 | } |
73 | |
|
74 | 0 | cparams.yarn_attn_factor *= hparams.rope_attn_factor; |
75 | |
|
76 | 0 | if (cparams.pooling_type == LLAMA_POOLING_TYPE_UNSPECIFIED) { |
77 | 0 | if (hparams.pooling_type == LLAMA_POOLING_TYPE_UNSPECIFIED) { |
78 | 0 | cparams.pooling_type = LLAMA_POOLING_TYPE_NONE; |
79 | 0 | } else { |
80 | 0 | cparams.pooling_type = hparams.pooling_type; |
81 | 0 | } |
82 | 0 | } |
83 | |
|
84 | 0 | if (params.attention_type == LLAMA_ATTENTION_TYPE_UNSPECIFIED) { |
85 | 0 | cparams.causal_attn = hparams.causal_attn; |
86 | 0 | } else { |
87 | 0 | cparams.causal_attn = params.attention_type == LLAMA_ATTENTION_TYPE_CAUSAL; |
88 | 0 | } |
89 | |
|
90 | 0 | cparams.flash_attn = params.flash_attn_type != LLAMA_FLASH_ATTN_TYPE_DISABLED; |
91 | | |
92 | | // with causal attention, the batch size is limited by the context size |
93 | 0 | cparams.n_batch = cparams.causal_attn ? std::min(cparams.n_ctx, params.n_batch) : params.n_batch; |
94 | | |
95 | | // the batch has to be at least GGML_KQ_MASK_PAD because we will be padding the KQ_mask |
96 | | // this is required by GPU kernels in order to avoid out-of-bounds accesses (e.g. ggml_flash_attn_ext) |
97 | | // ref: https://github.com/ggerganov/llama.cpp/pull/5021 |
98 | | // TODO: this padding is not needed for the cache-less context so we should probably move it to llama_memory |
99 | 0 | if (cparams.n_batch < GGML_KQ_MASK_PAD) { |
100 | 0 | LLAMA_LOG_WARN("%s: n_batch is less than GGML_KQ_MASK_PAD - increasing to %d\n", __func__, GGML_KQ_MASK_PAD); |
101 | 0 | cparams.n_batch = GGML_KQ_MASK_PAD; |
102 | 0 | } |
103 | 0 | cparams.n_ubatch = std::min(cparams.n_batch, params.n_ubatch == 0 ? params.n_batch : params.n_ubatch); |
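
To make the batch sizing above concrete, here is a small standalone sketch of the same clamping logic with made-up numbers; kq_mask_pad is a local stand-in for GGML_KQ_MASK_PAD and its value here is only assumed for illustration.

    #include <algorithm>
    #include <cstdint>
    #include <cstdio>

    int main() {
        const uint32_t kq_mask_pad = 64; // stand-in for GGML_KQ_MASK_PAD (assumed value)

        // hypothetical request: causal model, 4096-token context
        uint32_t n_ctx    = 4096;
        uint32_t p_batch  = 8192; // params.n_batch
        uint32_t p_ubatch = 0;    // params.n_ubatch == 0 means "use n_batch"

        uint32_t n_batch = std::min(n_ctx, p_batch);       // causal: clamp to the context -> 4096
        n_batch          = std::max(n_batch, kq_mask_pad); // never below the KQ mask padding
        const uint32_t n_ubatch = std::min(n_batch, p_ubatch == 0 ? p_batch : p_ubatch); // 4096

        printf("n_batch = %u, n_ubatch = %u\n", n_batch, n_ubatch);
        return 0;
    }
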
104 | |
|
105 | 0 | cparams.op_offload = params.op_offload; |
106 | 0 | cparams.kv_unified = params.kv_unified; |
107 | |
|
108 | 0 | { |
109 | 0 | const char * LLAMA_GRAPH_REUSE_DISABLE = getenv("LLAMA_GRAPH_REUSE_DISABLE"); |
110 | 0 | graph_reuse_disable = LLAMA_GRAPH_REUSE_DISABLE ? (atoi(LLAMA_GRAPH_REUSE_DISABLE) != 0) : graph_reuse_disable; |
111 | |
|
112 | 0 | if (graph_reuse_disable) { |
113 | 0 | LLAMA_LOG_WARN("%s: graph reuse disabled\n", __func__); |
114 | 0 | } |
115 | 0 | } |
116 | | |
117 | | // ref: https://github.com/ggml-org/llama.cpp/pull/17046#discussion_r2503085732 |
118 | 0 | cparams.n_ctx = GGML_PAD(cparams.n_ctx, 256); |
119 | |
|
120 | 0 | if (cparams.kv_unified) { |
121 | 0 | cparams.n_ctx_seq = cparams.n_ctx; |
122 | 0 | } else { |
123 | 0 | cparams.n_ctx_seq = cparams.n_ctx / cparams.n_seq_max; |
124 | 0 | cparams.n_ctx_seq = GGML_PAD(cparams.n_ctx_seq, 256); |
125 | |
|
126 | 0 | if (cparams.n_ctx_seq == 0) { |
127 | 0 | throw std::runtime_error("n_ctx_seq == 0"); |
128 | 0 | } |
129 | | |
130 | 0 | if (cparams.n_ctx != cparams.n_ctx_seq * cparams.n_seq_max) { |
131 | 0 | cparams.n_ctx = cparams.n_ctx_seq * cparams.n_seq_max; |
132 | 0 | LLAMA_LOG_WARN("%s: n_ctx is not divisible by n_seq_max - rounding down to %u\n", __func__, cparams.n_ctx); |
133 | 0 | } |
134 | 0 | } |
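
The context rounding above can be followed with a quick worked example; pad_to below is a local helper mirroring what ggml's padding macro does, and the input numbers are hypothetical.

    #include <cstdint>
    #include <cstdio>

    // round x up to the next multiple of n (n must be a power of two)
    static uint32_t pad_to(uint32_t x, uint32_t n) {
        return (x + n - 1) & ~(n - 1);
    }

    int main() {
        uint32_t n_ctx     = 10000; // requested total context
        uint32_t n_seq_max = 4;     // parallel sequences, non-unified KV cache

        n_ctx = pad_to(n_ctx, 256);              // -> 10240
        uint32_t n_ctx_seq = n_ctx / n_seq_max;  // -> 2560
        n_ctx_seq = pad_to(n_ctx_seq, 256);      // -> 2560 (already aligned)
        n_ctx     = n_ctx_seq * n_seq_max;       // -> 10240

        printf("n_ctx = %u, n_ctx_seq = %u\n", n_ctx, n_ctx_seq);
        return 0;
    }
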
135 | | |
136 | 0 | LLAMA_LOG_INFO("%s: n_seq_max = %u\n", __func__, cparams.n_seq_max); |
137 | 0 | LLAMA_LOG_INFO("%s: n_ctx = %u\n", __func__, cparams.n_ctx); |
138 | 0 | LLAMA_LOG_INFO("%s: n_ctx_seq = %u\n", __func__, cparams.n_ctx_seq); |
139 | 0 | LLAMA_LOG_INFO("%s: n_batch = %u\n", __func__, cparams.n_batch); |
140 | 0 | LLAMA_LOG_INFO("%s: n_ubatch = %u\n", __func__, cparams.n_ubatch); |
141 | 0 | LLAMA_LOG_INFO("%s: causal_attn = %d\n", __func__, cparams.causal_attn); |
142 | 0 | LLAMA_LOG_INFO("%s: flash_attn = %s\n", __func__, llama_flash_attn_type_name(params.flash_attn_type)); |
143 | 0 | LLAMA_LOG_INFO("%s: kv_unified = %s\n", __func__, cparams.kv_unified ? "true" : "false"); |
144 | 0 | LLAMA_LOG_INFO("%s: freq_base = %.1f\n", __func__, cparams.rope_freq_base); |
145 | 0 | LLAMA_LOG_INFO("%s: freq_scale = %g\n", __func__, cparams.rope_freq_scale); |
146 | |
|
147 | 0 | if (cparams.n_ctx_seq < hparams.n_ctx_train) { |
148 | 0 | LLAMA_LOG_WARN("%s: n_ctx_seq (%u) < n_ctx_train (%u) -- the full capacity of the model will not be utilized\n", |
149 | 0 | __func__, cparams.n_ctx_seq, hparams.n_ctx_train); |
150 | 0 | } |
151 | |
|
152 | 0 | if (cparams.n_ctx_seq > hparams.n_ctx_train) { |
153 | 0 | LLAMA_LOG_WARN("%s: n_ctx_seq (%u) > n_ctx_train (%u) -- possible training context overflow\n", |
154 | 0 | __func__, cparams.n_ctx_seq, hparams.n_ctx_train); |
155 | 0 | } |
156 | |
|
157 | 0 | if (!hparams.vocab_only) { |
158 | | // GPU backends |
159 | 0 | for (auto * dev : model.devices) { |
160 | 0 | ggml_backend_t backend = ggml_backend_dev_init(dev, nullptr); |
161 | 0 | if (backend == nullptr) { |
162 | 0 | throw std::runtime_error(format("failed to initialize %s backend", ggml_backend_dev_name(dev))); |
163 | 0 | } |
164 | 0 | backends.emplace_back(backend); |
165 | 0 | } |
166 | | |
167 | | // add ACCEL backends (such as BLAS) |
168 | 0 | for (size_t i = 0; i < ggml_backend_dev_count(); ++i) { |
169 | 0 | ggml_backend_dev_t dev = ggml_backend_dev_get(i); |
170 | 0 | if (ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_ACCEL) { |
171 | 0 | ggml_backend_t backend = ggml_backend_dev_init(dev, nullptr); |
172 | 0 | if (backend == nullptr) { |
173 | 0 | throw std::runtime_error(format("failed to initialize %s backend", ggml_backend_dev_name(dev))); |
174 | 0 | } |
175 | 0 | backends.emplace_back(backend); |
176 | 0 | } |
177 | 0 | } |
178 | | |
179 | | // add CPU backend |
180 | 0 | backend_cpu = ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_CPU, nullptr); |
181 | 0 | if (backend_cpu == nullptr) { |
182 | 0 | throw std::runtime_error("failed to initialize CPU backend"); |
183 | 0 | } |
184 | 0 | backends.emplace_back(backend_cpu); |
185 | | |
186 | | // create a list of the set_n_threads functions in the backends |
187 | 0 | for (auto & backend : backends) { |
188 | 0 | ggml_backend_dev_t dev = ggml_backend_get_device(backend.get()); |
189 | 0 | ggml_backend_reg_t reg = dev ? ggml_backend_dev_backend_reg(dev) : nullptr; |
190 | 0 | if (reg) { |
191 | 0 | auto ggml_backend_set_n_threads_fn = (ggml_backend_set_n_threads_t) ggml_backend_reg_get_proc_address(reg, "ggml_backend_set_n_threads"); |
192 | 0 | if (ggml_backend_set_n_threads_fn) { |
193 | 0 | set_n_threads_fns.emplace_back(backend.get(), ggml_backend_set_n_threads_fn); |
194 | 0 | } |
195 | 0 | } |
196 | 0 | } |
197 | |
|
198 | 0 | llama_set_abort_callback(this, params.abort_callback, params.abort_callback_data); |
199 | | |
200 | | // graph outputs buffer |
201 | 0 | { |
202 | | // resized during inference when a batch uses more outputs |
203 | 0 | if (output_reserve(params.n_seq_max) < params.n_seq_max) { |
204 | 0 | throw std::runtime_error("failed to reserve initial output buffer"); |
205 | 0 | } |
206 | | |
207 | 0 | LLAMA_LOG_INFO("%s: %10s output buffer size = %8.2f MiB\n", __func__, |
208 | 0 | ggml_backend_buffer_name (buf_output.get()), |
209 | 0 | ggml_backend_buffer_get_size(buf_output.get()) / 1024.0 / 1024.0); |
210 | 0 | } |
211 | 0 | } |
212 | | |
213 | | // init the memory module |
214 | 0 | if (!hparams.vocab_only) { |
215 | 0 | llama_memory_params params_mem = { |
216 | 0 | /*.type_k =*/ params.type_k, |
217 | 0 | /*.type_v =*/ params.type_v, |
218 | 0 | /*.swa_full =*/ params.swa_full, |
219 | 0 | }; |
220 | |
|
221 | 0 | memory.reset(model.create_memory(params_mem, cparams)); |
222 | 0 | } |
223 | | |
224 | | // init backends |
225 | 0 | if (!hparams.vocab_only) { |
226 | 0 | LLAMA_LOG_DEBUG("%s: enumerating backends\n", __func__); |
227 | |
|
228 | 0 | backend_buft.clear(); |
229 | 0 | backend_ptrs.clear(); |
230 | |
|
231 | 0 | for (auto & backend : backends) { |
232 | 0 | auto * buft = ggml_backend_get_default_buffer_type(backend.get()); |
233 | 0 | auto backend_type = ggml_backend_dev_type(ggml_backend_get_device(backend.get())); |
234 | |
|
235 | 0 | if (backend_type == GGML_BACKEND_DEVICE_TYPE_CPU && !model.devices.empty()) { |
236 | | // for the CPU backend, use the host buffer of the first device for faster transfer of the intermediate state |
237 | 0 | auto * dev = model.devices[0]; |
238 | 0 | auto * host_buft = ggml_backend_dev_host_buffer_type(dev); |
239 | 0 | if (host_buft) { |
240 | 0 | buft = host_buft; |
241 | 0 | } |
242 | 0 | } |
243 | |
|
244 | 0 | backend_buft.push_back(buft); |
245 | 0 | backend_ptrs.push_back(backend.get()); |
246 | 0 | } |
247 | |
|
248 | 0 | LLAMA_LOG_DEBUG("%s: backend_ptrs.size() = %zu\n", __func__, backend_ptrs.size()); |
249 | |
|
250 | 0 | const size_t max_nodes = this->graph_max_nodes(); |
251 | |
|
252 | 0 | LLAMA_LOG_DEBUG("%s: max_nodes = %zu\n", __func__, max_nodes); |
253 | |
|
254 | 0 | gf_res_prev.reset(new llm_graph_result(max_nodes)); |
255 | 0 | gf_res_reserve.reset(new llm_graph_result(max_nodes)); |
256 | | |
257 | | // TODO: move these checks to ggml_backend_sched |
258 | | // enabling pipeline parallelism in the scheduler increases memory usage, so it is only done when necessary |
259 | 0 | bool pipeline_parallel = |
260 | 0 | model.n_devices() > 1 && |
261 | 0 | model.params.n_gpu_layers > (int) model.hparams.n_layer && |
262 | 0 | model.params.split_mode == LLAMA_SPLIT_MODE_LAYER && |
263 | 0 | cparams.offload_kqv && |
264 | 0 | !model.has_tensor_overrides(); |
265 | | |
266 | | // pipeline parallelism requires support for async compute and events in all devices |
267 | 0 | if (pipeline_parallel) { |
268 | 0 | for (auto & backend : backends) { |
269 | 0 | auto dev_type = ggml_backend_dev_type(ggml_backend_get_device(backend.get())); |
270 | 0 | if (dev_type == GGML_BACKEND_DEVICE_TYPE_CPU) { |
271 | | // ignore CPU backend |
272 | 0 | continue; |
273 | 0 | } |
274 | 0 | auto * dev = ggml_backend_get_device(backend.get()); |
275 | 0 | ggml_backend_dev_props props; |
276 | 0 | ggml_backend_dev_get_props(dev, &props); |
277 | 0 | if (!props.caps.async || !props.caps.events) { |
278 | | // device does not support async compute or events |
279 | 0 | pipeline_parallel = false; |
280 | 0 | break; |
281 | 0 | } |
282 | 0 | } |
283 | 0 | } |
284 | |
|
285 | 0 | sched.reset(ggml_backend_sched_new(backend_ptrs.data(), backend_buft.data(), backend_ptrs.size(), max_nodes, pipeline_parallel, cparams.op_offload)); |
286 | |
|
287 | 0 | if (pipeline_parallel) { |
288 | 0 | LLAMA_LOG_INFO("%s: pipeline parallelism enabled (n_copies=%d)\n", __func__, ggml_backend_sched_get_n_copies(sched.get())); |
289 | 0 | } |
290 | |
|
291 | 0 | llama_memory_context_ptr mctx; |
292 | 0 | if (memory) { |
293 | 0 | LLAMA_LOG_DEBUG("%s: reserving full memory module\n", __func__); |
294 | 0 | mctx = memory->init_full(); |
295 | 0 | if (!mctx) { |
296 | 0 | throw std::runtime_error("failed to initialize memory module"); |
297 | 0 | } |
298 | 0 | } |
299 | | |
300 | 0 | cross.v_embd.clear(); |
301 | |
|
302 | 0 | const uint32_t n_seqs = cparams.kv_unified ? 1 : cparams.n_seq_max; |
303 | 0 | const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch); |
304 | | |
305 | | // avoid reserving graphs with zero outputs - assume one output per sequence |
306 | 0 | n_outputs = n_seqs; |
307 | |
|
308 | 0 | LLAMA_LOG_DEBUG("%s: worst-case: n_tokens = %d, n_seqs = %d, n_outputs = %d\n", __func__, n_tokens, n_seqs, n_outputs); |
309 | | |
310 | | // resolve automatic Flash Attention use |
311 | 0 | if (params.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_AUTO) { |
312 | 0 | auto * gf = graph_reserve(1, n_seqs, n_outputs, mctx.get(), true); |
313 | 0 | if (!gf) { |
314 | 0 | throw std::runtime_error("failed to split graph for Flash Attention check"); |
315 | 0 | } |
316 | | |
317 | 0 | const size_t prefix_len = strlen(LLAMA_TENSOR_NAME_FATTN) + 1; |
318 | 0 | bool fa_device_mismatch = false; |
319 | 0 | for (int i = 0; i < ggml_graph_n_nodes(gf); i++) { |
320 | 0 | ggml_tensor * n = ggml_graph_node(gf, i); |
321 | 0 | if (n->op != GGML_OP_FLASH_ATTN_EXT) { |
322 | 0 | continue; |
323 | 0 | } |
324 | 0 | ggml_backend_dev_t device_fa = ggml_backend_get_device( |
325 | 0 | ggml_backend_sched_get_tensor_backend(sched.get(), n)); |
326 | | |
327 | | // TODO: instead of the tensor names, use a map to keep track of which (FA) tensors belong to which layer |
328 | 0 | GGML_ASSERT(strncmp(n->name, LLAMA_TENSOR_NAME_FATTN "-", prefix_len) == 0); |
329 | 0 | const int il = std::stoi(n->name + prefix_len); |
330 | 0 | ggml_backend_dev_t device_kv = model.dev_layer(il); |
331 | 0 | if (device_fa != device_kv) { |
332 | 0 | LLAMA_LOG_WARN("%s: layer %d is assigned to device %s but the Flash Attention tensor " |
333 | 0 | "is assigned to device %s (usually due to missing support)\n", |
334 | 0 | __func__, il, ggml_backend_dev_name(device_kv), ggml_backend_dev_name(device_fa)); |
335 | | // FIXME: fa_device_mismatch logic is wrong for --no-kv-offload, but this is broken anyways |
336 | 0 | fa_device_mismatch = true; |
337 | 0 | break; |
338 | 0 | } |
339 | 0 | } |
340 | 0 | if (fa_device_mismatch) { |
341 | 0 | cparams.flash_attn = false; |
342 | 0 | LLAMA_LOG_WARN("%s: Flash Attention was auto, set to disabled\n", __func__); |
343 | 0 | if (ggml_is_quantized(params.type_v)) { |
344 | 0 | throw std::runtime_error("quantized V cache was requested, but this requires Flash Attention"); |
345 | 0 | } |
346 | 0 | } else { |
347 | 0 | cparams.flash_attn = true; |
348 | 0 | LLAMA_LOG_INFO("%s: Flash Attention was auto, set to enabled\n", __func__); |
349 | 0 | } |
350 | 0 | } |
351 | | |
352 | | // reserve worst-case graph |
353 | 0 | int n_splits_pp = -1; |
354 | 0 | int n_nodes_pp = -1; |
355 | |
|
356 | 0 | int n_splits_tg = -1; |
357 | 0 | int n_nodes_tg = -1; |
358 | | |
359 | | // reserve pp (prompt processing) graph first so that buffers are only allocated once |
360 | 0 | { |
361 | 0 | auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get()); |
362 | 0 | if (!gf) { |
363 | 0 | if (pipeline_parallel) { |
364 | 0 | LLAMA_LOG_WARN("%s: compute buffer allocation failed, retrying without pipeline parallelism\n", __func__); |
365 | 0 | sched.reset(ggml_backend_sched_new(backend_ptrs.data(), backend_buft.data(), backend_ptrs.size(), max_nodes, false, cparams.op_offload)); |
366 | 0 | gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get()); |
367 | 0 | } |
368 | 0 | if (!gf) { |
369 | 0 | throw std::runtime_error("failed to allocate compute pp buffers"); |
370 | 0 | } |
371 | 0 | } |
372 | | |
373 | 0 | n_splits_pp = ggml_backend_sched_get_n_splits(sched.get()); |
374 | 0 | n_nodes_pp = ggml_graph_n_nodes(gf); |
375 | 0 | } |
376 | | |
377 | | // reserve with tg (token generation) graph to get the number of splits and nodes |
378 | 0 | { |
379 | 0 | auto * gf = graph_reserve(n_seqs, n_seqs, n_seqs, mctx.get()); |
380 | 0 | if (!gf) { |
381 | 0 | throw std::runtime_error("failed to allocate compute tg buffers"); |
382 | 0 | } |
383 | | |
384 | 0 | n_splits_tg = ggml_backend_sched_get_n_splits(sched.get()); |
385 | 0 | n_nodes_tg = ggml_graph_n_nodes(gf); |
386 | 0 | } |
387 | | |
388 | | // reserve again with pp graph to avoid ggml-alloc reallocations during inference |
389 | 0 | { |
390 | | // TODO: not sure if the following graph would be the worst case for multi-stream KV caches: |
391 | | // |
392 | | // auto * gf = graph_reserve(n_tokens, 1, n_tokens, mctx.get()); |
393 | | // |
394 | 0 | auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get()); |
395 | 0 | if (!gf) { |
396 | 0 | throw std::runtime_error("failed to allocate compute pp buffers"); |
397 | 0 | } |
398 | 0 | } |
399 | | |
400 | 0 | for (size_t i = 0; i < backend_ptrs.size(); ++i) { |
401 | 0 | ggml_backend_t backend = backend_ptrs[i]; |
402 | 0 | ggml_backend_buffer_type_t buft = backend_buft[i]; |
403 | 0 | size_t size = ggml_backend_sched_get_buffer_size(sched.get(), backend); |
404 | 0 | if (size > 1) { |
405 | 0 | LLAMA_LOG_INFO("%s: %10s compute buffer size = %8.2f MiB\n", __func__, |
406 | 0 | ggml_backend_buft_name(buft), |
407 | 0 | size / 1024.0 / 1024.0); |
408 | 0 | } |
409 | 0 | } |
410 | |
|
411 | 0 | if (n_nodes_pp == n_nodes_tg) { |
412 | 0 | LLAMA_LOG_INFO("%s: graph nodes = %d\n", __func__, n_nodes_pp); |
413 | 0 | } else { |
414 | 0 | LLAMA_LOG_INFO("%s: graph nodes = %d (with bs=%d), %d (with bs=1)\n", __func__, n_nodes_pp, n_tokens, n_nodes_tg); |
415 | 0 | } |
416 | |
|
417 | 0 | if (n_splits_pp == n_splits_tg) { |
418 | 0 | LLAMA_LOG_INFO("%s: graph splits = %d\n", __func__, n_splits_pp); |
419 | 0 | } else { |
420 | 0 | LLAMA_LOG_INFO("%s: graph splits = %d (with bs=%d), %d (with bs=1)\n", __func__, n_splits_pp, n_tokens, n_splits_tg); |
421 | 0 | } |
422 | 0 | } |
423 | 0 | } |
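
For orientation, the constructor above is normally reached through the public C API rather than instantiated directly. A minimal sketch of that path is shown below; the function names follow the current llama.h as I understand it and should be checked against your headers, and the model path is a placeholder.

    #include "llama.h"
    #include <cstdio>

    int main() {
        llama_model_params mparams = llama_model_default_params();
        llama_model * model = llama_model_load_from_file("model.gguf", mparams); // placeholder path
        if (!model) {
            fprintf(stderr, "failed to load model\n");
            return 1;
        }

        llama_context_params cparams = llama_context_default_params();
        cparams.n_ctx     = 8192; // 0 would mean "use the model's training context"
        cparams.n_batch   = 2048;
        cparams.n_seq_max = 1;

        llama_context * ctx = llama_init_from_model(model, cparams); // runs the constructor above
        if (!ctx) {
            fprintf(stderr, "failed to create context\n");
            llama_model_free(model);
            return 1;
        }

        // ... llama_encode() / llama_decode() calls go here ...

        llama_free(ctx);
        llama_model_free(model);
        return 0;
    }
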
424 | | |
425 | 0 | llama_context::~llama_context() { |
426 | 0 | ggml_opt_free(opt_ctx); |
427 | 0 | } |
428 | | |
429 | 0 | void llama_context::synchronize() { |
430 | 0 | ggml_backend_sched_synchronize(sched.get()); |
431 | | |
432 | | // FIXME: if multiple single tokens are evaluated without a synchronization, |
433 | | // the stats will be added to the prompt evaluation stats |
434 | | // this should only happen when using batch size 1 to evaluate a batch |
435 | | |
436 | | // add the evaluation to the stats |
437 | 0 | if (n_queued_tokens == 1) { |
438 | 0 | if (!cparams.no_perf) { |
439 | 0 | t_eval_us += ggml_time_us() - t_compute_start_us; |
440 | 0 | } |
441 | 0 | n_eval++; |
442 | 0 | } else if (n_queued_tokens > 1) { |
443 | 0 | if (!cparams.no_perf) { |
444 | 0 | t_p_eval_us += ggml_time_us() - t_compute_start_us; |
445 | 0 | } |
446 | 0 | n_p_eval += n_queued_tokens; |
447 | 0 | } |
448 | | |
449 | | // get a more accurate load time, upon first eval |
450 | 0 | if (n_queued_tokens > 0 && !has_evaluated_once) { |
451 | 0 | t_load_us = ggml_time_us() - t_start_us; |
452 | 0 | has_evaluated_once = true; |
453 | 0 | } |
454 | |
|
455 | 0 | n_queued_tokens = 0; |
456 | 0 | t_compute_start_us = 0; |
457 | 0 | } |
458 | | |
459 | 0 | const llama_model & llama_context::get_model() const { |
460 | 0 | return model; |
461 | 0 | } |
462 | | |
463 | 0 | const llama_cparams & llama_context::get_cparams() const { |
464 | 0 | return cparams; |
465 | 0 | } |
466 | | |
467 | 0 | ggml_backend_sched_t llama_context::get_sched() const { |
468 | 0 | return sched.get(); |
469 | 0 | } |
470 | | |
471 | 0 | uint32_t llama_context::n_ctx() const { |
472 | 0 | return cparams.n_ctx; |
473 | 0 | } |
474 | | |
475 | 0 | uint32_t llama_context::n_ctx_seq() const { |
476 | 0 | return cparams.n_ctx_seq; |
477 | 0 | } |
478 | | |
479 | 0 | uint32_t llama_context::n_batch() const { |
480 | 0 | return cparams.n_batch; |
481 | 0 | } |
482 | | |
483 | 0 | uint32_t llama_context::n_ubatch() const { |
484 | 0 | return cparams.n_ubatch; |
485 | 0 | } |
486 | | |
487 | 0 | uint32_t llama_context::n_seq_max() const { |
488 | 0 | return cparams.n_seq_max; |
489 | 0 | } |
490 | | |
491 | 0 | uint32_t llama_context::n_threads() const { |
492 | 0 | return cparams.n_threads; |
493 | 0 | } |
494 | | |
495 | 0 | uint32_t llama_context::n_threads_batch() const { |
496 | 0 | return cparams.n_threads_batch; |
497 | 0 | } |
498 | | |
499 | 0 | llama_memory_t llama_context::get_memory() const { |
500 | 0 | return memory.get(); |
501 | 0 | } |
502 | | |
503 | 0 | bool llama_context::memory_update(bool optimize) { |
504 | 0 | if (!memory) { |
505 | 0 | return false; |
506 | 0 | } |
507 | | |
508 | 0 | { |
509 | 0 | const auto mctx = memory->init_update(this, optimize); |
510 | 0 | switch (mctx->get_status()) { |
511 | 0 | case LLAMA_MEMORY_STATUS_SUCCESS: |
512 | 0 | { |
513 | | // noop |
514 | 0 | } break; |
515 | 0 | case LLAMA_MEMORY_STATUS_NO_UPDATE: |
516 | 0 | { |
517 | | // no updates need to be performed |
518 | 0 | return false; |
519 | 0 | } |
520 | 0 | case LLAMA_MEMORY_STATUS_FAILED_PREPARE: |
521 | 0 | case LLAMA_MEMORY_STATUS_FAILED_COMPUTE: |
522 | 0 | { |
523 | 0 | LLAMA_LOG_ERROR("%s: failed to prepare memory update\n", __func__); |
524 | 0 | return false; |
525 | 0 | } |
526 | 0 | } |
527 | | |
528 | | // reset the previous graph result to make sure that it won't be reused |
529 | | // TODO: change the mctx->apply() to return information if a graph reserve is needed |
530 | | // reset the graph result only if the memory module did reset the scheduler |
531 | 0 | gf_res_prev->reset(); |
532 | |
|
533 | 0 | if (!mctx->apply()) { |
534 | 0 | LLAMA_LOG_ERROR("%s: failed to apply memory update\n", __func__); |
535 | 0 | } |
536 | 0 | } |
537 | | |
538 | | // if the memory module did any computation, we have to reserve a new worst-case graph |
539 | 0 | { |
540 | 0 | const auto mctx = memory->init_full(); |
541 | 0 | if (!mctx) { |
542 | 0 | throw std::runtime_error("failed to initialize memory context"); |
543 | 0 | } |
544 | | |
545 | 0 | const uint32_t n_seqs = cparams.kv_unified ? 1 : cparams.n_seq_max; |
546 | 0 | const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch); |
547 | |
|
548 | 0 | auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get()); |
549 | 0 | if (!gf) { |
550 | 0 | LLAMA_LOG_ERROR("%s: failed to reserve graph after the memory update\n", __func__); |
551 | 0 | } |
552 | 0 | } |
553 | | |
554 | 0 | return true; |
555 | 0 | } |
556 | | |
557 | 0 | enum llama_pooling_type llama_context::pooling_type() const { |
558 | 0 | return cparams.pooling_type; |
559 | 0 | } |
560 | | |
561 | 0 | float * llama_context::get_logits() { |
562 | 0 | output_reorder(); |
563 | |
|
564 | 0 | return logits; |
565 | 0 | } |
566 | | |
567 | 0 | float * llama_context::get_logits_ith(int32_t i) { |
568 | 0 | int64_t j = -1; |
569 | |
|
570 | 0 | output_reorder(); |
571 | |
|
572 | 0 | try { |
573 | 0 | if (logits == nullptr) { |
574 | 0 | throw std::runtime_error("no logits"); |
575 | 0 | } |
576 | | |
577 | 0 | if (i < 0) { |
578 | 0 | j = n_outputs + i; |
579 | 0 | if (j < 0) { |
580 | 0 | throw std::runtime_error(format("negative index out of range [0, %d)", n_outputs)); |
581 | 0 | } |
582 | 0 | } else if ((size_t) i >= output_ids.size()) { |
583 | 0 | throw std::runtime_error(format("out of range [0, %zu)", output_ids.size())); |
584 | 0 | } else { |
585 | 0 | j = output_ids[i]; |
586 | 0 | } |
587 | | |
588 | 0 | if (j < 0) { |
589 | 0 | throw std::runtime_error(format("batch.logits[%d] != true", i)); |
590 | 0 | } |
591 | 0 | if (j >= n_outputs) { |
592 | | // This should not happen |
593 | 0 | throw std::runtime_error(format("corrupt output buffer (j=%" PRId64 ", n_outputs=%d)", j, n_outputs)); |
594 | 0 | } |
595 | | |
596 | 0 | return logits + j*model.vocab.n_tokens(); |
597 | 0 | } catch (const std::exception & err) { |
598 | 0 | LLAMA_LOG_ERROR("%s: invalid logits id %d, reason: %s\n", __func__, i, err.what()); |
599 | | #ifndef NDEBUG |
600 | | GGML_ABORT("fatal error"); |
601 | | #else |
602 | 0 | return nullptr; |
603 | 0 | #endif |
604 | 0 | } |
605 | 0 | } |
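
From the caller's perspective, the index passed to llama_get_logits_ith() is the position within the submitted batch, and only positions whose logits flag was set have an output row. A hedged sketch follows; the context and tokenized prompt are assumed to be prepared elsewhere.

    #include "llama.h"
    #include <vector>

    // evaluate a prompt and return a pointer into the context's logits buffer for
    // the last token, or nullptr on failure
    static const float * eval_prompt_logits(llama_context * ctx, const std::vector<llama_token> & prompt) {
        llama_batch batch = llama_batch_init((int32_t) prompt.size(), 0, 1);
        for (size_t i = 0; i < prompt.size(); ++i) {
            batch.token   [i]    = prompt[i];
            batch.pos     [i]    = (llama_pos) i;
            batch.n_seq_id[i]    = 1;
            batch.seq_id  [i][0] = 0;
            batch.logits  [i]    = (i + 1 == prompt.size()); // request output only for the last position
        }
        batch.n_tokens = (int32_t) prompt.size();

        const float * logits = nullptr;
        if (llama_decode(ctx, batch) == 0) {
            // valid because batch.logits[] was set for this index; other indices would
            // trigger the "invalid logits id" path above
            logits = llama_get_logits_ith(ctx, batch.n_tokens - 1);
        }
        llama_batch_free(batch);
        return logits;
    }
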
606 | | |
607 | 0 | float * llama_context::get_embeddings() { |
608 | 0 | output_reorder(); |
609 | |
|
610 | 0 | return embd; |
611 | 0 | } |
612 | | |
613 | 0 | float * llama_context::get_embeddings_ith(int32_t i) { |
614 | 0 | int64_t j = -1; |
615 | |
|
616 | 0 | output_reorder(); |
617 | |
|
618 | 0 | try { |
619 | 0 | if (embd == nullptr) { |
620 | 0 | throw std::runtime_error("no embeddings"); |
621 | 0 | } |
622 | | |
623 | 0 | if (i < 0) { |
624 | 0 | j = n_outputs + i; |
625 | 0 | if (j < 0) { |
626 | 0 | throw std::runtime_error(format("negative index out of range [0, %d)", n_outputs)); |
627 | 0 | } |
628 | 0 | } else if ((size_t) i >= output_ids.size()) { |
629 | 0 | throw std::runtime_error(format("out of range [0, %zu)", output_ids.size())); |
630 | 0 | } else { |
631 | 0 | j = output_ids[i]; |
632 | 0 | } |
633 | | |
634 | 0 | if (j < 0) { |
635 | 0 | throw std::runtime_error(format("batch.logits[%d] != true", i)); |
636 | 0 | } |
637 | 0 | if (j >= n_outputs) { |
638 | | // This should not happen |
639 | 0 | throw std::runtime_error(format("corrupt output buffer (j=%" PRId64 ", n_outputs=%d)", j, n_outputs)); |
640 | 0 | } |
641 | | |
642 | 0 | return embd + j*model.hparams.n_embd; |
643 | 0 | } catch (const std::exception & err) { |
644 | 0 | LLAMA_LOG_ERROR("%s: invalid embeddings id %d, reason: %s\n", __func__, i, err.what()); |
645 | | #ifndef NDEBUG |
646 | | GGML_ABORT("fatal error"); |
647 | | #else |
648 | 0 | return nullptr; |
649 | 0 | #endif |
650 | 0 | } |
651 | 0 | } |
652 | | |
653 | 0 | float * llama_context::get_embeddings_seq(llama_seq_id seq_id) { |
654 | 0 | auto it = embd_seq.find(seq_id); |
655 | 0 | if (it == embd_seq.end()) { |
656 | 0 | return nullptr; |
657 | 0 | } |
658 | | |
659 | 0 | return it->second.data(); |
660 | 0 | } |
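
A short usage sketch for the sequence embeddings accessor above. It assumes a context created with embeddings enabled and a pooling type other than LLAMA_POOLING_TYPE_NONE (for RANK pooling the per-sequence output is n_cls_out scores rather than n_embd floats), and that the sequence has already been decoded; llama_model_n_embd is the accessor name I believe the current API uses.

    #include "llama.h"
    #include <vector>

    // copy the pooled embedding of a sequence into a vector (empty on failure)
    static std::vector<float> pooled_embedding(llama_context * ctx, const llama_model * model, llama_seq_id seq_id) {
        const int32_t n_embd = llama_model_n_embd(model);

        const float * p = llama_get_embeddings_seq(ctx, seq_id);
        if (p == nullptr) {
            return {}; // nothing was pooled for this sequence
        }

        return std::vector<float>(p, p + n_embd);
    }
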
661 | | |
662 | | void llama_context::attach_threadpool( |
663 | | ggml_threadpool_t threadpool, |
664 | 0 | ggml_threadpool_t threadpool_batch) { |
665 | 0 | LLAMA_LOG_DEBUG("%s: call\n", __func__); |
666 | |
|
667 | 0 | this->threadpool = threadpool; |
668 | 0 | this->threadpool_batch = threadpool_batch ? threadpool_batch : threadpool; |
669 | 0 | } |
670 | | |
671 | 0 | void llama_context::detach_threadpool() { |
672 | 0 | LLAMA_LOG_DEBUG("%s: call\n", __func__); |
673 | |
|
674 | 0 | this->threadpool = nullptr; |
675 | 0 | this->threadpool_batch = nullptr; |
676 | 0 | } |
677 | | |
678 | 0 | void llama_context::set_n_threads(int32_t n_threads, int32_t n_threads_batch) { |
679 | 0 | LLAMA_LOG_DEBUG("%s: n_threads = %d, n_threads_batch = %d\n", __func__, n_threads, n_threads_batch); |
680 | |
|
681 | 0 | cparams.n_threads = n_threads; |
682 | 0 | cparams.n_threads_batch = n_threads_batch; |
683 | 0 | } |
684 | | |
685 | 0 | void llama_context::set_abort_callback(bool (*abort_callback)(void * data), void * abort_callback_data) { |
686 | 0 | LLAMA_LOG_DEBUG("%s: call\n", __func__); |
687 | |
|
688 | 0 | this->abort_callback = abort_callback; |
689 | 0 | this->abort_callback_data = abort_callback_data; |
690 | |
|
691 | 0 | for (auto & backend : backends) { |
692 | 0 | auto * reg = ggml_backend_dev_backend_reg(ggml_backend_get_device(backend.get())); |
693 | 0 | auto * set_abort_callback_fn = (ggml_backend_set_abort_callback_t) ggml_backend_reg_get_proc_address(reg, "ggml_backend_set_abort_callback"); |
694 | 0 | if (set_abort_callback_fn) { |
695 | 0 | set_abort_callback_fn(backend.get(), this->abort_callback, this->abort_callback_data); |
696 | 0 | } |
697 | 0 | } |
698 | 0 | } |
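
The abort callback registered above is polled by the backends during graph computation; returning true makes the current decode stop and report an aborted status. A minimal sketch, assuming the context is created elsewhere:

    #include "llama.h"
    #include <atomic>

    static std::atomic<bool> g_interrupted{false};

    // returning true asks the backends to abort the in-flight computation
    static bool abort_cb(void * /*user_data*/) {
        return g_interrupted.load();
    }

    // during setup:
    //   llama_set_abort_callback(ctx, abort_cb, nullptr);
    // a signal handler or another thread can then set g_interrupted = true
    // to cut a long-running llama_decode() short.
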
699 | | |
700 | 0 | void llama_context::set_embeddings(bool value) { |
701 | 0 | LLAMA_LOG_DEBUG("%s: value = %d\n", __func__, value); |
702 | |
|
703 | 0 | cparams.embeddings = value; |
704 | 0 | } |
705 | | |
706 | 0 | void llama_context::set_causal_attn(bool value) { |
707 | 0 | LLAMA_LOG_DEBUG("%s: value = %d\n", __func__, value); |
708 | |
|
709 | 0 | cparams.causal_attn = value; |
710 | 0 | } |
711 | | |
712 | 0 | void llama_context::set_warmup(bool value) { |
713 | 0 | LLAMA_LOG_DEBUG("%s: value = %d\n", __func__, value); |
714 | |
|
715 | 0 | cparams.warmup = value; |
716 | 0 | } |
717 | | |
718 | | void llama_context::set_adapter_lora( |
719 | | llama_adapter_lora * adapter, |
720 | 0 | float scale) { |
721 | 0 | LLAMA_LOG_DEBUG("%s: adapter = %p, scale = %f\n", __func__, (void *) adapter, scale); |
722 | |
|
723 | 0 | loras[adapter] = scale; |
724 | 0 | } |
725 | | |
726 | | bool llama_context::rm_adapter_lora( |
727 | 0 | llama_adapter_lora * adapter) { |
728 | 0 | LLAMA_LOG_DEBUG("%s: adapter = %p\n", __func__, (void *) adapter); |
729 | |
|
730 | 0 | auto pos = loras.find(adapter); |
731 | 0 | if (pos != loras.end()) { |
732 | 0 | loras.erase(pos); |
733 | 0 | return true; |
734 | 0 | } |
735 | | |
736 | 0 | return false; |
737 | 0 | } |
738 | | |
739 | 0 | void llama_context::clear_adapter_lora() { |
740 | 0 | LLAMA_LOG_DEBUG("%s: call\n", __func__); |
741 | |
|
742 | 0 | loras.clear(); |
743 | 0 | } |
744 | | |
745 | | bool llama_context::apply_adapter_cvec( |
746 | | const float * data, |
747 | | size_t len, |
748 | | int32_t n_embd, |
749 | | int32_t il_start, |
750 | 0 | int32_t il_end) { |
751 | 0 | LLAMA_LOG_DEBUG("%s: il_start = %d, il_end = %d\n", __func__, il_start, il_end); |
752 | |
|
753 | 0 | return cvec.apply(model, data, len, n_embd, il_start, il_end); |
754 | 0 | } |
755 | | |
756 | 0 | llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, llm_graph_type gtype, llama_memory_context_i * mctx, ggml_status & ret) { |
757 | 0 | if (mctx && !mctx->apply()) { |
758 | 0 | LLAMA_LOG_ERROR("%s: failed to apply memory context\n", __func__); |
759 | 0 | ret = GGML_STATUS_FAILED; |
760 | 0 | return nullptr; |
761 | 0 | } |
762 | | |
763 | 0 | auto * res = gf_res_prev.get(); |
764 | 0 | auto * gf = res->get_gf(); |
765 | | |
766 | | // the new graph parameters |
767 | | // in order to correctly reuse a graph, its full topology has to be uniquely determined by these parameters |
768 | 0 | const auto gparams = graph_params(res, ubatch, mctx, gtype); |
769 | |
|
770 | 0 | if (!graph_reuse_disable && res->can_reuse(gparams)) { |
771 | | //LLAMA_LOG_DEBUG("%s: reusing previous graph\n", __func__); |
772 | |
|
773 | 0 | n_reused++; |
774 | 0 | } else { |
775 | 0 | res->reset(); |
776 | |
|
777 | 0 | ggml_backend_sched_reset(sched.get()); |
778 | 0 | ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data); |
779 | | |
780 | | //const auto t_start_us = ggml_time_us(); |
781 | |
|
782 | 0 | gf = model.build_graph(gparams); |
783 | | |
784 | | //LLAMA_LOG_INFO("graph build time: %.3f ms\n", (ggml_time_us() - t_start_us)/1000.0); |
785 | |
|
786 | 0 | if (!gf) { |
787 | 0 | LLAMA_LOG_ERROR("%s: failed to initialize graph\n", __func__); |
788 | 0 | ret = GGML_STATUS_FAILED; |
789 | 0 | return nullptr; |
790 | 0 | } |
791 | | |
792 | 0 | if (!ggml_backend_sched_alloc_graph(sched.get(), gf)) { |
793 | 0 | LLAMA_LOG_ERROR("%s: failed to allocate graph\n", __func__); |
794 | 0 | ret = GGML_STATUS_ALLOC_FAILED; |
795 | 0 | return nullptr; |
796 | 0 | } |
797 | 0 | } |
798 | | |
799 | | // set the input data for the input tensors |
800 | 0 | { |
801 | | //const auto t_start_us = ggml_time_us(); |
802 | |
|
803 | 0 | res->set_inputs(&ubatch); |
804 | | |
805 | | //LLAMA_LOG_INFO("graph set inputs time: %.3f ms\n", (ggml_time_us() - t_start_us)/1000.0); |
806 | 0 | } |
807 | |
|
808 | 0 | const auto status = graph_compute(res->get_gf(), ubatch.n_tokens > 1); |
809 | 0 | if (status != GGML_STATUS_SUCCESS) { |
810 | 0 | LLAMA_LOG_ERROR("%s: failed to compute graph, compute status: %d\n", __func__, status); |
811 | 0 | ret = status; |
812 | 0 | return nullptr; |
813 | 0 | } |
814 | | |
815 | 0 | ret = GGML_STATUS_SUCCESS; |
816 | |
|
817 | 0 | return res; |
818 | 0 | } |
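
The reuse check above amounts to "rebuild and reallocate the graph only when its defining parameters change". The following is a schematic, library-agnostic sketch of that pattern, not the actual llm_graph_result API:

    #include <cstdint>

    // schematic stand-ins, not llama.cpp types
    struct graph_key {
        uint32_t n_tokens  = 0;
        uint32_t n_outputs = 0;
        int      gtype     = 0;

        bool operator==(const graph_key & other) const {
            return n_tokens == other.n_tokens && n_outputs == other.n_outputs && gtype == other.gtype;
        }
    };

    struct cached_graph {
        bool      valid = false;
        graph_key key;

        // true when the previously built graph can be evaluated again as-is
        bool can_reuse(const graph_key & k) const { return valid && key == k; }

        void rebuild(const graph_key & k) {
            // ... build the graph topology and (re)allocate its buffers here ...
            key   = k;
            valid = true;
        }
    };
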
819 | | |
820 | 0 | int llama_context::encode(const llama_batch & batch_inp) { |
821 | 0 | GGML_ASSERT((!batch_inp.token && batch_inp.embd) || (batch_inp.token && !batch_inp.embd)); // NOLINT |
822 | |
|
823 | 0 | if (batch_inp.n_tokens == 0) { |
824 | 0 | LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__); |
825 | 0 | return -1; |
826 | 0 | } |
827 | | |
828 | 0 | const auto & hparams = model.hparams; |
829 | |
|
830 | 0 | const int64_t n_embd = hparams.n_embd_inp(); |
831 | 0 | const int64_t n_vocab = model.vocab.n_tokens(); |
832 | | |
833 | | // note: during encode, we always pass the full sequence starting from pos = 0 |
834 | 0 | if (!balloc->init(batch_inp, model.vocab, nullptr, n_embd, cparams.kv_unified ? LLAMA_MAX_SEQ : cparams.n_seq_max, true)) { |
835 | 0 | LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__); |
836 | 0 | return -1; |
837 | 0 | } |
838 | | |
839 | 0 | const uint32_t n_tokens = balloc->get_n_tokens(); |
840 | | |
841 | | // [TAG_NO_CACHE_PAD] |
842 | | // TODO: add new split mode where we pad the input sequences so that ubatch.equal_seqs == true |
843 | 0 | const llama_ubatch ubatch = balloc->split_simple(n_tokens); |
844 | | |
845 | | // micro-batching is not possible for non-causal encoding, so we process the batch in a single shot |
846 | 0 | GGML_ASSERT(cparams.n_ubatch >= n_tokens && "encoder requires n_ubatch >= n_tokens"); |
847 | |
|
848 | 0 | if (t_compute_start_us == 0) { |
849 | 0 | t_compute_start_us = ggml_time_us(); |
850 | 0 | } |
851 | | |
852 | | // TODO: this clear of the buffer can easily be forgotten - need something better |
853 | 0 | embd_seq.clear(); |
854 | |
|
855 | 0 | n_queued_tokens += n_tokens; |
856 | | |
857 | | // reserve output buffer |
858 | 0 | if (output_reserve(n_tokens) < n_tokens) { |
859 | 0 | LLAMA_LOG_ERROR("%s: could not reserve space for batch with %u outputs\n", __func__, n_tokens); |
860 | 0 | return -2; |
861 | 0 | }; |
862 | |
|
863 | 0 | for (uint32_t i = 0; i < n_tokens; ++i) { |
864 | 0 | output_ids[i] = i; |
865 | 0 | } |
866 | |
|
867 | 0 | n_outputs = n_tokens; |
868 | |
|
869 | 0 | const auto causal_attn_org = cparams.causal_attn; |
870 | | |
871 | | // always use non-causal attention for encoder graphs |
872 | | // TODO: this is a tmp solution until we have a proper way to support enc-dec models |
873 | | // ref: https://github.com/ggml-org/llama.cpp/pull/12181#issuecomment-2730451223 |
874 | 0 | cparams.causal_attn = false; |
875 | |
|
876 | 0 | ggml_status status; |
877 | 0 | const auto * res = process_ubatch(ubatch, LLM_GRAPH_TYPE_ENCODER, nullptr, status); |
878 | |
|
879 | 0 | cparams.causal_attn = causal_attn_org; |
880 | |
|
881 | 0 | if (!res) { |
882 | 0 | switch (status) { |
883 | 0 | case GGML_STATUS_ABORTED: return 2; |
884 | 0 | case GGML_STATUS_ALLOC_FAILED: return -2; |
885 | 0 | case GGML_STATUS_FAILED: return -3; |
886 | 0 | case GGML_STATUS_SUCCESS: GGML_ABORT("should not happen"); |
887 | 0 | } |
888 | 0 | } |
889 | | |
890 | 0 | auto * t_logits = res->get_logits(); |
891 | 0 | auto * t_embd = res->get_embd_pooled() ? res->get_embd_pooled() : res->get_embd(); |
892 | | |
893 | | // extract logits |
894 | 0 | if (logits && t_logits) { |
895 | 0 | ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(sched.get(), t_logits); |
896 | 0 | GGML_ASSERT(backend_res != nullptr); |
897 | 0 | GGML_ASSERT(logits != nullptr); |
898 | |
|
899 | 0 | ggml_backend_tensor_get_async(backend_res, t_logits, logits, 0, n_tokens*n_vocab*sizeof(float)); |
900 | 0 | } |
901 | | |
902 | | // extract embeddings |
903 | 0 | if (embd && t_embd) { |
904 | 0 | ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(sched.get(), t_embd); |
905 | 0 | GGML_ASSERT(backend_embd != nullptr); |
906 | |
|
907 | 0 | switch (cparams.pooling_type) { |
908 | 0 | case LLAMA_POOLING_TYPE_NONE: |
909 | 0 | { |
910 | | // extract token embeddings |
911 | 0 | GGML_ASSERT(embd != nullptr); |
912 | |
|
913 | 0 | GGML_ASSERT(n_tokens*n_embd <= (int64_t) embd_size); |
914 | 0 | ggml_backend_tensor_get_async(backend_embd, t_embd, embd, 0, n_tokens*n_embd*sizeof(float)); |
915 | 0 | } break; |
916 | 0 | case LLAMA_POOLING_TYPE_MEAN: |
917 | 0 | case LLAMA_POOLING_TYPE_CLS: |
918 | 0 | case LLAMA_POOLING_TYPE_LAST: |
919 | 0 | { |
920 | | // extract sequence embeddings |
921 | 0 | auto & embd_seq_out = embd_seq; |
922 | |
|
923 | 0 | for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) { |
924 | 0 | const llama_seq_id seq_id = ubatch.seq_id_unq[s]; |
925 | 0 | const int32_t seq_idx = ubatch.seq_idx[seq_id]; |
926 | |
|
927 | 0 | embd_seq_out[seq_id].resize(n_embd); |
928 | 0 | ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_embd*seq_idx)*sizeof(float), n_embd*sizeof(float)); |
929 | 0 | } |
930 | 0 | } break; |
931 | 0 | case LLAMA_POOLING_TYPE_RANK: |
932 | 0 | { |
933 | | // extract the rerank score - n_cls_out floats per sequence |
934 | 0 | auto & embd_seq_out = embd_seq; |
935 | |
|
936 | 0 | const uint32_t n_cls_out = hparams.n_cls_out; |
937 | |
|
938 | 0 | for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) { |
939 | 0 | const llama_seq_id seq_id = ubatch.seq_id_unq[s]; |
940 | 0 | const int32_t seq_idx = ubatch.seq_idx[seq_id]; |
941 | |
|
942 | 0 | embd_seq_out[seq_id].resize(n_cls_out); |
943 | 0 | ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_cls_out*seq_idx)*sizeof(float), n_cls_out*sizeof(float)); |
944 | 0 | } |
945 | 0 | } break; |
946 | 0 | case LLAMA_POOLING_TYPE_UNSPECIFIED: |
947 | 0 | { |
948 | 0 | GGML_ABORT("unknown pooling type"); |
949 | 0 | } |
950 | 0 | } |
951 | 0 | } |
952 | | |
953 | | // TODO: hacky solution |
954 | 0 | if (model.arch == LLM_ARCH_T5 && t_embd) { |
955 | | //cross.t_embd = t_embd; |
956 | |
|
957 | 0 | synchronize(); |
958 | |
|
959 | 0 | cross.n_embd = t_embd->ne[0]; |
960 | 0 | cross.n_enc = t_embd->ne[1]; |
961 | 0 | cross.v_embd.resize(cross.n_embd*cross.n_enc); |
962 | 0 | memcpy(cross.v_embd.data(), embd, ggml_nbytes(t_embd)); |
963 | |
|
964 | 0 | const auto & batch = balloc->get_batch(); |
965 | | |
966 | | // remember the sequence ids used during the encoding - needed for cross attention later |
967 | 0 | cross.seq_ids_enc.resize(n_tokens); |
968 | 0 | for (uint32_t i = 0; i < n_tokens; i++) { |
969 | 0 | cross.seq_ids_enc[i].clear(); |
970 | |
|
971 | 0 | for (int s = 0; s < batch.n_seq_id[i]; s++) { |
972 | 0 | const llama_seq_id seq_id = batch.seq_id[i][s]; |
973 | |
|
974 | 0 | cross.seq_ids_enc[i].insert(seq_id); |
975 | 0 | } |
976 | 0 | } |
977 | 0 | } |
978 | |
|
979 | 0 | return 0; |
980 | 0 | } |
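
A caller-side sketch of the encoder path above: because micro-batching is not available here, the whole prompt is submitted in one batch (so n_ubatch must be at least the prompt length). The context and tokenized prompt are assumed to be prepared elsewhere.

    #include "llama.h"
    #include <vector>

    // run the encoder over a full prompt in a single call; returns llama_encode()'s status
    static int encode_prompt(llama_context * ctx, const std::vector<llama_token> & prompt) {
        llama_batch batch = llama_batch_init((int32_t) prompt.size(), 0, 1);
        for (size_t i = 0; i < prompt.size(); ++i) {
            batch.token   [i]    = prompt[i];
            batch.pos     [i]    = (llama_pos) i; // encoding always starts from pos = 0
            batch.n_seq_id[i]    = 1;
            batch.seq_id  [i][0] = 0;
            batch.logits  [i]    = true; // every encoder position produces an output
        }
        batch.n_tokens = (int32_t) prompt.size();

        const int ret = llama_encode(ctx, batch); // 0 on success, negative on error
        llama_batch_free(batch);
        return ret;
    }
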
981 | | |
982 | 0 | int llama_context::decode(const llama_batch & batch_inp) { |
983 | 0 | GGML_ASSERT((!batch_inp.token && batch_inp.embd) || (batch_inp.token && !batch_inp.embd)); // NOLINT |
984 | |
|
985 | 0 | if (!memory) { |
986 | 0 | LLAMA_LOG_DEBUG("%s: cannot decode batches with this context (calling encode() instead)\n", __func__); |
987 | 0 | return encode(batch_inp); |
988 | 0 | } |
989 | | |
990 | 0 | if (batch_inp.n_tokens == 0) { |
991 | 0 | LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__); |
992 | 0 | return -1; |
993 | 0 | } |
994 | | |
995 | 0 | const auto & vocab = model.vocab; |
996 | 0 | const auto & hparams = model.hparams; |
997 | |
|
998 | 0 | const int64_t n_vocab = vocab.n_tokens(); |
999 | 0 | const int64_t n_embd = hparams.n_embd_inp(); |
1000 | | |
1001 | | // when computing embeddings, all tokens are output |
1002 | 0 | const bool output_all = cparams.embeddings; |
1003 | |
|
1004 | 0 | if (!balloc->init(batch_inp, vocab, memory.get(), n_embd, cparams.kv_unified ? LLAMA_MAX_SEQ : cparams.n_seq_max, output_all)) { |
1005 | 0 | LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__); |
1006 | 0 | return -1; |
1007 | 0 | } |
1008 | | |
1009 | 0 | const uint32_t n_tokens_all = balloc->get_n_tokens(); |
1010 | 0 | const uint32_t n_outputs_all = balloc->get_n_outputs(); |
1011 | |
|
1012 | 0 | if (output_all) { |
1013 | | // require that all tokens are output |
1014 | 0 | if (n_outputs_all != n_tokens_all) { |
1015 | 0 | LLAMA_LOG_ERROR("%s: pooled embedding requires that all tokens are output (n_outputs_all = %d, n_tokens_all = %d)\n", |
1016 | 0 | __func__, n_outputs_all, n_tokens_all); |
1017 | 0 | return -1; |
1018 | 0 | } |
1019 | 0 | } |
1020 | | |
1021 | 0 | GGML_ASSERT(n_tokens_all <= cparams.n_batch); |
1022 | |
|
1023 | 0 | GGML_ASSERT((cparams.causal_attn || cparams.n_ubatch >= n_tokens_all) && "non-causal attention requires n_ubatch >= n_tokens"); |
1024 | |
|
1025 | 0 | if (t_compute_start_us == 0) { |
1026 | 0 | t_compute_start_us = ggml_time_us(); |
1027 | 0 | } |
1028 | 0 | n_queued_tokens += n_tokens_all; |
1029 | | |
1030 | | // TODO: this clear of the buffer can easily be forgotten - need something better |
1031 | 0 | embd_seq.clear(); |
1032 | 0 | output_swaps.clear(); |
1033 | |
|
1034 | 0 | bool did_optimize = false; |
1035 | | |
1036 | | // handle any pending shifts/copies |
1037 | 0 | memory_update(false); |
1038 | |
|
1039 | 0 | llama_memory_context_ptr mctx; |
1040 | |
|
1041 | 0 | while (true) { |
1042 | 0 | mctx = memory->init_batch(*balloc, cparams.n_ubatch, output_all); |
1043 | 0 | if (!mctx) { |
1044 | 0 | return -2; |
1045 | 0 | } |
1046 | | |
1047 | 0 | switch (mctx->get_status()) { |
1048 | 0 | case LLAMA_MEMORY_STATUS_SUCCESS: |
1049 | 0 | { |
1050 | 0 | } break; |
1051 | 0 | case LLAMA_MEMORY_STATUS_NO_UPDATE: |
1052 | 0 | { |
1053 | 0 | LLAMA_LOG_ERROR("%s: unexpected memory context status: %d\n", __func__, mctx->get_status()); |
1054 | |
|
1055 | 0 | return -2; |
1056 | 0 | } |
1057 | 0 | case LLAMA_MEMORY_STATUS_FAILED_PREPARE: |
1058 | 0 | { |
1059 | 0 | if (!did_optimize) { |
1060 | 0 | did_optimize = true; |
1061 | |
|
1062 | 0 | if (memory_update(true)) { |
1063 | 0 | LLAMA_LOG_DEBUG("%s: retrying batch size %d after cache optimization\n", __func__, balloc->get_n_tokens()); |
1064 | |
|
1065 | 0 | continue; |
1066 | 0 | } |
1067 | 0 | } |
1068 | | |
1069 | 0 | LLAMA_LOG_WARN("%s: failed to find a memory slot for batch of size %d\n", __func__, balloc->get_n_tokens()); |
1070 | |
|
1071 | 0 | return 1; |
1072 | 0 | } |
1073 | 0 | case LLAMA_MEMORY_STATUS_FAILED_COMPUTE: |
1074 | 0 | { |
1075 | 0 | LLAMA_LOG_ERROR("%s: compute failed while preparing batch of size %d\n", __func__, balloc->get_n_tokens()); |
1076 | |
|
1077 | 0 | return -2; |
1078 | 0 | } |
1079 | 0 | } |
1080 | | |
1081 | 0 | break; |
1082 | 0 | } |
1083 | | |
1084 | | // reserve output buffer |
1085 | 0 | if (output_reserve(n_outputs_all) < n_outputs_all) { |
1086 | 0 | LLAMA_LOG_ERROR("%s: could not reserve space for batch with %d outputs\n", __func__, n_outputs_all); |
1087 | 0 | return -2; |
1088 | 0 | }; |
1089 | |
|
1090 | 0 | int64_t n_outputs_prev = 0; |
1091 | |
|
1092 | 0 | do { |
1093 | 0 | const auto & ubatch = mctx->get_ubatch(); |
1094 | | |
1095 | | // count the outputs in this ubatch |
1096 | 0 | { |
1097 | 0 | int32_t n_outputs_new = 0; |
1098 | |
|
1099 | 0 | if (n_outputs_all == n_tokens_all) { |
1100 | 0 | n_outputs_new = ubatch.n_tokens; |
1101 | 0 | } else { |
1102 | 0 | for (uint32_t i = 0; i < ubatch.n_tokens; i++) { |
1103 | 0 | n_outputs_new += (int32_t) (ubatch.output[i] != 0); |
1104 | 0 | } |
1105 | 0 | } |
1106 | | |
1107 | | // needs to happen before the graph is built |
1108 | 0 | n_outputs = n_outputs_new; |
1109 | 0 | } |
1110 | |
|
1111 | 0 | ggml_status status; |
1112 | 0 | const auto * res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, mctx.get(), status); |
1113 | |
|
1114 | 0 | if (!res) { |
1115 | | // the last ubatch failed or was aborted -> remove all positions of that ubatch from the memory module |
1116 | 0 | llama_pos pos_min[LLAMA_MAX_SEQ]; |
1117 | 0 | for (int s = 0; s < LLAMA_MAX_SEQ; ++s) { |
1118 | 0 | pos_min[s] = std::numeric_limits<llama_pos>::max(); |
1119 | 0 | } |
1120 | |
|
1121 | 0 | for (uint32_t i = 0; i < ubatch.n_tokens; ++i) { |
1122 | 0 | const auto & seq_id = ubatch.seq_id[i][0]; |
1123 | |
|
1124 | 0 | pos_min[seq_id] = std::min(pos_min[seq_id], ubatch.pos[i]); |
1125 | 0 | } |
1126 | |
|
1127 | 0 | for (int s = 0; s < LLAMA_MAX_SEQ; ++s) { |
1128 | 0 | if (pos_min[s] == std::numeric_limits<llama_pos>::max()) { |
1129 | 0 | continue; |
1130 | 0 | } |
1131 | | |
1132 | 0 | LLAMA_LOG_WARN("%s: removing memory module entries for seq_id = %d, pos = [%d, +inf)\n", __func__, s, pos_min[s]); |
1133 | |
|
1134 | 0 | memory->seq_rm(s, pos_min[s], -1); |
1135 | 0 | } |
1136 | |
|
1137 | 0 | switch (status) { |
1138 | 0 | case GGML_STATUS_ABORTED: return 2; |
1139 | 0 | case GGML_STATUS_ALLOC_FAILED: return -2; |
1140 | 0 | case GGML_STATUS_FAILED: return -3; |
1141 | 0 | case GGML_STATUS_SUCCESS: GGML_ABORT("should not happen"); |
1142 | 0 | } |
1143 | 0 | } |
1144 | | |
1145 | | // plot the computation graph in dot format (for debugging purposes) |
1146 | | //if (n_past%100 == 0) { |
1147 | | // ggml_graph_dump_dot(gf, NULL, "llama.dot"); |
1148 | | //} |
1149 | | |
1150 | 0 | auto * t_logits = res->get_logits(); |
1151 | 0 | auto * t_embd = cparams.embeddings ? res->get_embd() : nullptr; |
1152 | |
|
1153 | 0 | if (t_embd && res->get_embd_pooled()) { |
1154 | 0 | t_embd = res->get_embd_pooled(); |
1155 | 0 | } |
1156 | | |
1157 | | // extract logits |
1158 | 0 | if (t_logits && n_outputs > 0) { |
1159 | 0 | ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(sched.get(), t_logits); |
1160 | 0 | GGML_ASSERT(backend_res != nullptr); |
1161 | 0 | GGML_ASSERT(logits != nullptr); |
1162 | |
|
1163 | 0 | float * logits_out = logits + n_outputs_prev*n_vocab; |
1164 | |
|
1165 | 0 | if (n_outputs) { |
1166 | 0 | GGML_ASSERT( n_outputs_prev + n_outputs <= n_outputs_all); |
1167 | 0 | GGML_ASSERT((n_outputs_prev + n_outputs)*n_vocab <= (int64_t) logits_size); |
1168 | 0 | ggml_backend_tensor_get_async(backend_res, t_logits, logits_out, 0, n_outputs*n_vocab*sizeof(float)); |
1169 | 0 | } |
1170 | 0 | } |
1171 | | |
1172 | | // extract embeddings |
1173 | 0 | if (t_embd && n_outputs > 0) { |
1174 | 0 | ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(sched.get(), t_embd); |
1175 | 0 | GGML_ASSERT(backend_embd != nullptr); |
1176 | |
|
1177 | 0 | switch (cparams.pooling_type) { |
1178 | 0 | case LLAMA_POOLING_TYPE_NONE: |
1179 | 0 | { |
1180 | | // extract token embeddings |
1181 | 0 | GGML_ASSERT(embd != nullptr); |
1182 | 0 | float * embd_out = embd + n_outputs_prev*n_embd; |
1183 | |
|
1184 | 0 | if (n_outputs) { |
1185 | 0 | GGML_ASSERT( n_outputs_prev + n_outputs <= n_outputs_all); |
1186 | 0 | GGML_ASSERT((n_outputs_prev + n_outputs)*n_embd <= (int64_t) embd_size); |
1187 | 0 | ggml_backend_tensor_get_async(backend_embd, t_embd, embd_out, 0, n_outputs*n_embd*sizeof(float)); |
1188 | 0 | } |
1189 | 0 | } break; |
1190 | 0 | case LLAMA_POOLING_TYPE_MEAN: |
1191 | 0 | case LLAMA_POOLING_TYPE_CLS: |
1192 | 0 | case LLAMA_POOLING_TYPE_LAST: |
1193 | 0 | { |
1194 | | // extract sequence embeddings (cleared before processing each batch) |
1195 | 0 | auto & embd_seq_out = embd_seq; |
1196 | |
|
1197 | 0 | for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) { |
1198 | 0 | const llama_seq_id seq_id = ubatch.seq_id_unq[s]; |
1199 | 0 | const int32_t seq_idx = ubatch.seq_idx[seq_id]; |
1200 | |
|
1201 | 0 | embd_seq_out[seq_id].resize(n_embd); |
1202 | 0 | ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_embd*seq_idx)*sizeof(float), n_embd*sizeof(float)); |
1203 | 0 | } |
1204 | 0 | } break; |
1205 | 0 | case LLAMA_POOLING_TYPE_RANK: |
1206 | 0 | { |
1207 | | // extract the rerank score - n_cls_out floats per sequence |
1208 | 0 | auto & embd_seq_out = embd_seq; |
1209 | |
|
1210 | 0 | const uint32_t n_cls_out = hparams.n_cls_out; |
1211 | |
|
1212 | 0 | for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) { |
1213 | 0 | const llama_seq_id seq_id = ubatch.seq_id_unq[s]; |
1214 | 0 | const int32_t seq_idx = ubatch.seq_idx[seq_id]; |
1215 | |
|
1216 | 0 | embd_seq_out[seq_id].resize(n_cls_out); |
1217 | 0 | ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_cls_out*seq_idx)*sizeof(float), n_cls_out*sizeof(float)); |
1218 | 0 | } |
1219 | 0 | } break; |
1220 | 0 | case LLAMA_POOLING_TYPE_UNSPECIFIED: |
1221 | 0 | { |
1222 | 0 | GGML_ABORT("unknown pooling type"); |
1223 | 0 | } |
1224 | 0 | } |
1225 | 0 | } |
1226 | | |
1227 | 0 | n_outputs_prev += n_outputs; |
1228 | 0 | } while (mctx->next()); |
1229 | | |
1230 | | // set to total number of outputs in the batch, for use in llama_get_logits_ith |
1231 | 0 | n_outputs = n_outputs_all; |
1232 | | |
1233 | | // set output mappings |
1234 | 0 | if (n_outputs > 0) { |
1235 | 0 | bool sorted_output = true; |
1236 | |
|
1237 | 0 | auto & out_ids = balloc->get_out_ids(); |
1238 | |
|
1239 | 0 | GGML_ASSERT(out_ids.size() == (size_t) n_outputs); |
1240 | |
|
1241 | 0 | for (int64_t i = 0; i < n_outputs; ++i) { |
1242 | 0 | int64_t out_id = out_ids[i]; |
1243 | 0 | output_ids[out_id] = i; |
1244 | 0 | if (out_id != i) { |
1245 | 0 | sorted_output = false; |
1246 | 0 | } |
1247 | 0 | } |
1248 | | |
1249 | | // make the outputs have the same order they had in the user-provided batch |
1250 | | // note: this is mostly relevant for recurrent models atm |
1251 | 0 | if (!sorted_output) { |
1252 | 0 | GGML_ASSERT((size_t) n_outputs == out_ids.size()); |
1253 | | |
1254 | | // TODO: is there something more efficient which also minimizes swaps? |
1255 | | // selection sort, to minimize swaps (from https://en.wikipedia.org/wiki/Selection_sort) |
1256 | 0 | for (uint32_t i = 0; i < n_outputs - 1; ++i) { |
1257 | 0 | uint32_t j_min = i; |
1258 | 0 | for (uint32_t j = i + 1; j < n_outputs; ++j) { |
1259 | 0 | if (out_ids[j] < out_ids[j_min]) { |
1260 | 0 | j_min = j; |
1261 | 0 | } |
1262 | 0 | } |
1263 | 0 | if (j_min == i) { |
1264 | 0 | continue; |
1265 | 0 | } |
1266 | 0 | std::swap(out_ids[i], out_ids[j_min]); |
1267 | | |
1268 | | // remember the swaps and apply them lazily upon logits/embeddings access |
1269 | 0 | output_swaps.push_back({ i, j_min }); |
1270 | 0 | } |
1271 | |
|
1272 | 0 | std::fill(output_ids.begin(), output_ids.end(), -1); |
1273 | |
|
1274 | 0 | for (uint32_t i = 0; i < n_outputs; ++i) { |
1275 | 0 | output_ids[out_ids[i]] = i; |
1276 | 0 | } |
1277 | 0 | } |
1278 | 0 | } |
1279 | | |
1280 | | // wait for the computation to finish (automatically done when obtaining the model output) |
1281 | | //synchronize(); |
1282 | |
|
1283 | 0 | return 0; |
1284 | 0 | } |
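
The return codes produced by the method above map directly onto what callers of llama_decode() see; a small wrapper that spells them out (semantics taken from the switch statements above, worth cross-checking against llama.h):

    #include "llama.h"
    #include <cstdio>

    // decode a prepared batch and translate the status code into a boolean
    static bool decode_checked(llama_context * ctx, const llama_batch & batch) {
        const int ret = llama_decode(ctx, batch);
        switch (ret) {
            case 0:
                return true; // success
            case 1:
                fprintf(stderr, "no memory slot found for the batch - shrink the batch or free cache space\n");
                return false;
            case 2:
                fprintf(stderr, "computation aborted via the abort callback\n");
                return false;
            default:
                fprintf(stderr, "decode failed with status %d\n", ret);
                return false;
        }
    }
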
1285 | | |
1286 | | // |
1287 | | // output |
1288 | | // |
1289 | | |
1290 | 0 | uint32_t llama_context::output_reserve(int32_t n_outputs) { |
1291 | 0 | const auto & hparams = model.hparams; |
1292 | 0 | const auto & vocab = model.vocab; |
1293 | |
|
1294 | 0 | const int64_t n_outputs_max = std::max<int64_t>(n_outputs, n_seq_max()); |
1295 | |
|
1296 | 0 | const auto n_batch = cparams.n_batch; |
1297 | 0 | const auto n_vocab = vocab.n_tokens(); |
1298 | 0 | const auto n_embd = hparams.n_embd; |
1299 | |
|
1300 | 0 | bool has_logits = true; |
1301 | 0 | bool has_embd = cparams.embeddings; |
1302 | | |
1303 | | // TODO: hacky enc-dec support |
1304 | 0 | if (model.arch == LLM_ARCH_T5) { |
1305 | 0 | has_logits = true; |
1306 | 0 | has_embd = true; |
1307 | 0 | } |
1308 | |
|
1309 | 0 | logits_size = has_logits ? n_vocab*n_outputs_max : 0; |
1310 | 0 | embd_size = has_embd ? n_embd*n_outputs_max : 0; |
1311 | |
|
1312 | 0 | if (output_ids.empty()) { |
1313 | | // init, never resized afterwards |
1314 | 0 | output_ids.resize(n_batch); |
1315 | 0 | } |
1316 | |
|
1317 | 0 | const size_t prev_size = buf_output ? ggml_backend_buffer_get_size(buf_output.get()) : 0; |
1318 | 0 | const size_t new_size = (logits_size + embd_size) * sizeof(float); |
1319 | | |
1320 | | // alloc only when more than the current capacity is required |
1321 | | // TODO: also consider shrinking the buffer |
1322 | 0 | if (!buf_output || prev_size < new_size) { |
1323 | 0 | if (buf_output) { |
1324 | | #ifndef NDEBUG |
1325 | | // This doesn't happen often, but may be annoying in some cases (like the HellaSwag benchmark) |
1326 | | LLAMA_LOG_INFO("%s: reallocating output buffer from size %.02f MiB to %.02f MiB\n", __func__, prev_size / 1024.0 / 1024.0, new_size / 1024.0 / 1024.0); |
1327 | | #endif |
1328 | 0 | buf_output = nullptr; |
1329 | 0 | logits = nullptr; |
1330 | 0 | embd = nullptr; |
1331 | 0 | } |
1332 | |
|
1333 | 0 | auto * buft = ggml_backend_cpu_buffer_type(); |
1334 | | // try to use the host buffer of the device where the output tensor is allocated for faster transfer to system memory |
1335 | 0 | auto * output_dev = model.dev_output(); |
1336 | 0 | auto * output_dev_host_buft = output_dev ? ggml_backend_dev_host_buffer_type(output_dev) : nullptr; |
1337 | 0 | if (output_dev_host_buft) { |
1338 | 0 | buft = output_dev_host_buft; |
1339 | 0 | } |
1340 | 0 | buf_output.reset(ggml_backend_buft_alloc_buffer(buft, new_size)); |
1341 | 0 | if (buf_output == nullptr) { |
1342 | 0 | LLAMA_LOG_ERROR("%s: failed to allocate output buffer of size %.2f MiB\n", __func__, new_size / (1024.0 * 1024.0)); |
1343 | 0 | return 0; |
1344 | 0 | } |
1345 | 0 | } |
1346 | | |
1347 | 0 | float * output_base = (float *) ggml_backend_buffer_get_base(buf_output.get()); |
1348 | |
|
1349 | 0 | logits = has_logits ? output_base : nullptr; |
1350 | 0 | embd = has_embd ? output_base + logits_size : nullptr; |
1351 | | |
1352 | | // set all ids as invalid (negative) |
1353 | 0 | std::fill(output_ids.begin(), output_ids.end(), -1); |
1354 | |
|
1355 | 0 | this->n_outputs = 0; |
1356 | |
|
1357 | 0 | return n_outputs_max; |
1358 | 0 | } |
1359 | | |
1360 | 0 | void llama_context::output_reorder() { |
1361 | 0 | const uint64_t n_vocab = model.vocab.n_tokens(); |
1362 | 0 | const uint64_t n_embd = model.hparams.n_embd; |
1363 | |
|
1364 | 0 | for (size_t s = 0; s < output_swaps.size(); ++s) { |
1365 | 0 | const uint64_t i0 = output_swaps[s].i0; |
1366 | 0 | const uint64_t i1 = output_swaps[s].i1; |
1367 | |
|
1368 | 0 | if (logits_size > 0) { |
1369 | 0 | for (uint64_t k = 0; k < n_vocab; k++) { |
1370 | 0 | std::swap(logits[i0*n_vocab + k], logits[i1*n_vocab + k]); |
1371 | 0 | } |
1372 | 0 | } |
1373 | |
|
1374 | 0 | if (embd_size > 0) { |
1375 | 0 | for (uint64_t k = 0; k < n_embd; k++) { |
1376 | 0 | std::swap(embd[i0*n_embd + k], embd[i1*n_embd + k]); |
1377 | 0 | } |
1378 | 0 | } |
1379 | 0 | } |
1380 | |
|
1381 | 0 | output_swaps.clear(); |
1382 | 0 | } |
1383 | | |
1384 | | // |
1385 | | // graph |
1386 | | // |
1387 | | |
1388 | 0 | uint32_t llama_context::graph_max_nodes() const { |
1389 | 0 | return std::max<uint32_t>(1024u, 8u*model.n_tensors()); |
1390 | 0 | } |
1391 | | |
1392 | 0 | llm_graph_result * llama_context::get_gf_res_reserve() const { |
1393 | 0 | return static_cast<llm_graph_result *>(gf_res_reserve.get()); |
1394 | 0 | } |
1395 | | |
1396 | 0 | ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx, bool split_only) { |
1397 | 0 | LLAMA_LOG_DEBUG("%s: reserving a graph for ubatch with n_tokens = %4u, n_seqs = %2u, n_outputs = %4u\n", __func__, n_tokens, n_seqs, n_outputs); |
1398 | 0 | GGML_ASSERT(n_outputs >= 1); |
1399 | |
|
1400 | 0 | if (n_tokens % n_seqs != 0) { |
1401 | 0 | n_tokens = ((n_tokens + (n_seqs - 1)) / n_seqs) * n_seqs; // round to next multiple of n_seqs |
1402 | 0 | n_outputs = std::min(n_outputs, n_tokens); |
1403 | |
|
1404 | 0 | LLAMA_LOG_DEBUG("%s: making n_tokens a multiple of n_seqs - n_tokens = %u, n_seqs = %u, n_outputs = %u\n", __func__, n_tokens, n_seqs, n_outputs); |
1405 | 0 | } |
1406 | |
|
1407 | 0 | ggml_backend_sched_reset(sched.get()); |
1408 | | |
1409 | | // when the scheduler is reset, we cannot reuse the old graph, so we reset the previous graph result to prevent that
1410 | 0 | gf_res_prev->reset(); |
1411 | | |
1412 | | // store the n_outputs as it is, and restore it afterwards |
1413 | | // TODO: not sure if needed, might simplify in the future by removing this |
1414 | 0 | const auto save_n_outputs = this->n_outputs; |
1415 | |
|
1416 | 0 | this->n_outputs = n_outputs; |
1417 | |
|
1418 | 0 | llama_batch_allocr balloc(model.hparams.n_pos_per_embd()); |
1419 | 0 | llama_ubatch ubatch = balloc.ubatch_reserve(n_tokens/n_seqs, n_seqs); |
1420 | |
|
1421 | 0 | auto * res = gf_res_reserve.get(); |
1422 | |
|
1423 | 0 | const auto gparams = graph_params(res, ubatch, mctx, LLM_GRAPH_TYPE_DEFAULT); |
1424 | |
|
1425 | 0 | res->reset(); |
1426 | |
|
1427 | 0 | auto * gf = model.build_graph(gparams); |
1428 | |
|
1429 | 0 | this->n_outputs = save_n_outputs; |
1430 | | |
1431 | | // initialize scheduler with the specified graph |
1432 | 0 | if (split_only) { |
1433 | 0 | ggml_backend_sched_split_graph(sched.get(), gf); |
1434 | 0 | } else if (!ggml_backend_sched_reserve(sched.get(), gf)) { |
1435 | 0 | LLAMA_LOG_ERROR("%s: failed to allocate compute buffers\n", __func__); |
1436 | 0 | return nullptr; |
1437 | 0 | } |
1438 | | |
1439 | 0 | return gf; |
1440 | 0 | } |
1441 | | |
1442 | | llm_graph_params llama_context::graph_params( |
1443 | | llm_graph_result * res, |
1444 | | const llama_ubatch & ubatch, |
1445 | | const llama_memory_context_i * mctx, |
1446 | 0 | llm_graph_type gtype) const { |
1447 | 0 | return { |
1448 | 0 | /*.arch =*/ model.arch, |
1449 | 0 | /*.hparams =*/ model.hparams, |
1450 | 0 | /*.cparams =*/ cparams, |
1451 | 0 | /*.ubatch =*/ ubatch, |
1452 | 0 | /*.gtype =*/ gtype, |
1453 | 0 | /*.sched =*/ sched.get(), |
1454 | 0 | /*.backend_cpu =*/ backend_cpu, |
1455 | 0 | /*.cvec =*/ &cvec, |
1456 | 0 | /*.loras =*/ &loras, |
1457 | 0 | /*.mctx =*/ mctx, |
1458 | 0 | /*.cross =*/ &cross, |
1459 | 0 | /*.n_outputs =*/ n_outputs, |
1460 | 0 | /*.cb =*/ graph_get_cb(), |
1461 | 0 | /*.res =*/ res, |
1462 | 0 | }; |
1463 | 0 | } |
1464 | | |
1465 | | ggml_status llama_context::graph_compute( |
1466 | | ggml_cgraph * gf, |
1467 | 0 | bool batched) { |
1468 | 0 | int n_threads = batched ? cparams.n_threads_batch : cparams.n_threads; |
1469 | 0 | ggml_threadpool_t tp = batched ? threadpool_batch : threadpool; |
1470 | |
|
1471 | 0 | if (backend_cpu != nullptr) { |
1472 | 0 | auto * reg = ggml_backend_dev_backend_reg(ggml_backend_get_device(backend_cpu)); |
1473 | 0 | auto * set_threadpool_fn = (decltype(ggml_backend_cpu_set_threadpool) *) ggml_backend_reg_get_proc_address(reg, "ggml_backend_cpu_set_threadpool"); |
1474 | 0 | if (set_threadpool_fn) { |
1475 | 0 | set_threadpool_fn(backend_cpu, tp); |
1476 | 0 | } |
1477 | 0 | } |
1478 | | |
1479 | | // set the number of threads for all the backends |
1480 | 0 | for (const auto & set_n_threads_fn : set_n_threads_fns) { |
1481 | 0 | set_n_threads_fn.second(set_n_threads_fn.first, n_threads); |
1482 | 0 | } |
1483 | |
|
1484 | 0 | auto status = ggml_backend_sched_graph_compute_async(sched.get(), gf); |
1485 | 0 | if (status != GGML_STATUS_SUCCESS) { |
1486 | 0 | LLAMA_LOG_ERROR("%s: ggml_backend_sched_graph_compute_async failed with error %d\n", __func__, status); |
1487 | 0 | } |
1488 | | |
1489 | | // fprintf(stderr, "splits: %d\n", ggml_backend_sched_get_n_splits(sched)); |
1490 | |
|
1491 | 0 | return status; |
1492 | 0 | } |
1493 | | |
1494 | 0 | llm_graph_cb llama_context::graph_get_cb() const { |
1495 | 0 | return [&](const llama_ubatch & ubatch, ggml_tensor * cur, const char * name, int il) { |
1496 | 0 | if (il >= 0) { |
1497 | 0 | ggml_format_name(cur, "%s-%d", name, il); |
1498 | 0 | } else { |
1499 | 0 | ggml_set_name(cur, name); |
1500 | 0 | } |
1501 | |
|
1502 | 0 | if (!cparams.offload_kqv) { |
1503 | 0 | if (strcmp(name, "kqv_merged_cont") == 0) { |
1504 | | // all nodes between the KV store and the attention output are run on the CPU |
1505 | 0 | ggml_backend_sched_set_tensor_backend(sched.get(), cur, backend_cpu); |
1506 | 0 | } |
1507 | 0 | } |
1508 | | |
1509 | | // norm may be automatically assigned to the backend of the previous layer, increasing data transfer between backends |
1510 | | // FIXME: fix in ggml_backend_sched |
1511 | 0 | const bool full_offload = model.params.n_gpu_layers > (int) model.hparams.n_layer; |
1512 | 0 | if (ubatch.n_tokens < 32 || full_offload) { |
1513 | 0 | if (il != -1 && strcmp(name, "norm") == 0) { |
1514 | 0 | const auto & dev_layer = model.dev_layer(il); |
1515 | 0 | for (const auto & backend : backends) { |
1516 | 0 | if (ggml_backend_get_device(backend.get()) == dev_layer) { |
1517 | 0 | if (ggml_backend_supports_op(backend.get(), cur)) { |
1518 | 0 | ggml_backend_sched_set_tensor_backend(sched.get(), cur, backend.get()); |
1519 | 0 | } |
1520 | 0 | } |
1521 | 0 | } |
1522 | 0 | } |
1523 | 0 | } |
1524 | 0 | }; |
1525 | 0 | } |
1526 | | |
1527 | | // |
1528 | | // state save/load |
1529 | | // |
1530 | | |
1531 | | class llama_io_write_dummy : public llama_io_write_i { |
1532 | | public: |
1533 | 0 | llama_io_write_dummy() = default; |
1534 | | |
1535 | 0 | void write(const void * /* src */, size_t size) override { |
1536 | 0 | size_written += size; |
1537 | 0 | } |
1538 | | |
1539 | 0 | void write_tensor(const ggml_tensor * /* tensor */, size_t /* offset */, size_t size) override { |
1540 | 0 | size_written += size; |
1541 | 0 | } |
1542 | | |
1543 | 0 | size_t n_bytes() override { |
1544 | 0 | return size_written; |
1545 | 0 | } |
1546 | | |
1547 | | private: |
1548 | | size_t size_written = 0; |
1549 | | }; |
1550 | | |
1551 | | class llama_io_write_buffer : public llama_io_write_i { |
1552 | | public: |
1553 | | llama_io_write_buffer( |
1554 | 0 | uint8_t * p, size_t len) : ptr(p), buf_size(len) {} |
1555 | | |
1556 | 0 | void write(const void * src, size_t size) override { |
1557 | 0 | if (size > buf_size) { |
1558 | 0 | throw std::runtime_error("unexpectedly reached end of buffer"); |
1559 | 0 | } |
1560 | 0 | memcpy(ptr, src, size); |
1561 | 0 | ptr += size; |
1562 | 0 | size_written += size; |
1563 | 0 | buf_size -= size; |
1564 | 0 | } |
1565 | | |
1566 | 0 | void write_tensor(const ggml_tensor * tensor, size_t offset, size_t size) override { |
1567 | 0 | if (size > buf_size) { |
1568 | 0 | throw std::runtime_error("unexpectedly reached end of buffer"); |
1569 | 0 | } |
1570 | 0 | ggml_backend_tensor_get(tensor, ptr, offset, size); |
1571 | 0 | ptr += size; |
1572 | 0 | size_written += size; |
1573 | 0 | buf_size -= size; |
1574 | 0 | } |
1575 | | |
1576 | 0 | size_t n_bytes() override { |
1577 | 0 | return size_written; |
1578 | 0 | } |
1579 | | |
1580 | | private: |
1581 | | uint8_t * ptr; |
1582 | | size_t buf_size = 0; |
1583 | | size_t size_written = 0; |
1584 | | }; |
1585 | | |
1586 | | class llama_io_read_buffer : public llama_io_read_i { |
1587 | | public: |
1588 | 0 | llama_io_read_buffer(const uint8_t * p, size_t len) : ptr(p), buf_size(len) {} |
1589 | | |
1590 | 0 | const uint8_t * read(size_t size) override { |
1591 | 0 | const uint8_t * base_ptr = ptr; |
1592 | 0 | if (size > buf_size) { |
1593 | 0 | throw std::runtime_error("unexpectedly reached end of buffer"); |
1594 | 0 | } |
1595 | 0 | ptr += size; |
1596 | 0 | size_read += size; |
1597 | 0 | buf_size -= size; |
1598 | 0 | return base_ptr; |
1599 | 0 | } |
1600 | | |
1601 | 0 | void read_to(void * dst, size_t size) override { |
1602 | 0 | memcpy(dst, read(size), size); |
1603 | 0 | } |
1604 | | |
1605 | 0 | size_t n_bytes() override { |
1606 | 0 | return size_read; |
1607 | 0 | } |
1608 | | |
1609 | | private: |
1610 | | const uint8_t * ptr; |
1611 | | size_t buf_size = 0; |
1612 | | size_t size_read = 0; |
1613 | | }; |
1614 | | |
1615 | | class llama_io_write_file : public llama_io_write_i { |
1616 | | public: |
1617 | 0 | llama_io_write_file(llama_file * f) : file(f) {} |
1618 | | |
1619 | 0 | void write(const void * src, size_t size) override { |
1620 | 0 | file->write_raw(src, size); |
1621 | 0 | size_written += size; |
1622 | 0 | } |
1623 | | |
1624 | 0 | void write_tensor(const ggml_tensor * tensor, size_t offset, size_t size) override { |
1625 | 0 | temp_buffer.resize(size); |
1626 | 0 | ggml_backend_tensor_get(tensor, temp_buffer.data(), offset, size); |
1627 | 0 | write(temp_buffer.data(), temp_buffer.size()); |
1628 | 0 | } |
1629 | | |
1630 | 0 | size_t n_bytes() override { |
1631 | 0 | return size_written; |
1632 | 0 | } |
1633 | | |
1634 | | private: |
1635 | | llama_file * file; |
1636 | | size_t size_written = 0; |
1637 | | std::vector<uint8_t> temp_buffer; |
1638 | | }; |
1639 | | |
1640 | | class llama_io_read_file : public llama_io_read_i { |
1641 | | public: |
1642 | 0 | llama_io_read_file(llama_file * f) : file(f) {} |
1643 | | |
1644 | 0 | void read_to(void * dst, size_t size) override { |
1645 | 0 | file->read_raw(dst, size); |
1646 | 0 | size_read += size; |
1647 | 0 | } |
1648 | | |
1649 | 0 | const uint8_t * read(size_t size) override { |
1650 | 0 | temp_buffer.resize(size); |
1651 | 0 | read_to(temp_buffer.data(), size); |
1652 | 0 | return temp_buffer.data(); |
1653 | 0 | } |
1654 | | |
1655 | 0 | size_t n_bytes() override { |
1656 | 0 | return size_read; |
1657 | 0 | } |
1658 | | |
1659 | | private: |
1660 | | llama_file * file; |
1661 | | size_t size_read = 0; |
1662 | | std::vector<uint8_t> temp_buffer; |
1663 | | }; |
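// The helpers above are used in pairs by the functions below: a llama_io_write_dummy pass
// first to measure the exact state size, then a buffer- or file-backed writer to emit the
// same byte stream. A minimal sketch of that pattern (write_all stands in for a writer
// such as state_write_data below):
//
//     llama_io_write_dummy counter;
//     write_all(counter);                            // first pass: only count the bytes
//
//     std::vector<uint8_t> buf(counter.n_bytes());   // allocate exactly what is needed
//
//     llama_io_write_buffer writer(buf.data(), buf.size());
//     write_all(writer);                             // second pass: write the actual data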
1664 | | |
1665 | 0 | size_t llama_context::state_get_size() { |
1666 | 0 | llama_io_write_dummy io; |
1667 | 0 | try { |
1668 | 0 | return state_write_data(io); |
1669 | 0 | } catch (const std::exception & err) { |
1670 | 0 | LLAMA_LOG_ERROR("%s: error getting state size: %s\n", __func__, err.what()); |
1671 | 0 | return 0; |
1672 | 0 | } |
1673 | 0 | } |
1674 | | |
1675 | 0 | size_t llama_context::state_get_data(uint8_t * dst, size_t size) { |
1676 | 0 | llama_io_write_buffer io(dst, size); |
1677 | 0 | try { |
1678 | 0 | return state_write_data(io); |
1679 | 0 | } catch (const std::exception & err) { |
1680 | 0 | LLAMA_LOG_ERROR("%s: error saving state: %s\n", __func__, err.what()); |
1681 | 0 | return 0; |
1682 | 0 | } |
1683 | 0 | } |
1684 | | |
1685 | 0 | size_t llama_context::state_set_data(const uint8_t * src, size_t size) { |
1686 | 0 | llama_io_read_buffer io(src, size); |
1687 | 0 | try { |
1688 | 0 | return state_read_data(io); |
1689 | 0 | } catch (const std::exception & err) { |
1690 | 0 | LLAMA_LOG_ERROR("%s: error loading state: %s\n", __func__, err.what()); |
1691 | 0 | return 0; |
1692 | 0 | } |
1693 | 0 | } |
1694 | | |
1695 | 0 | size_t llama_context::state_seq_get_size(llama_seq_id seq_id, llama_state_seq_flags flags) { |
1696 | 0 | llama_io_write_dummy io; |
1697 | 0 | try { |
1698 | 0 | return state_seq_write_data(io, seq_id, flags); |
1699 | 0 | } catch (const std::exception & err) { |
1700 | 0 | LLAMA_LOG_ERROR("%s: error getting state size: %s\n", __func__, err.what()); |
1701 | 0 | return 0; |
1702 | 0 | } |
1703 | 0 | } |
1704 | | |
1705 | 0 | size_t llama_context::state_seq_get_data(llama_seq_id seq_id, uint8_t * dst, size_t size, llama_state_seq_flags flags) { |
1706 | 0 | llama_io_write_buffer io(dst, size); |
1707 | 0 | try { |
1708 | 0 | return state_seq_write_data(io, seq_id, flags); |
1709 | 0 | } catch (const std::exception & err) { |
1710 | 0 | LLAMA_LOG_ERROR("%s: error saving state: %s\n", __func__, err.what()); |
1711 | 0 | return 0; |
1712 | 0 | } |
1713 | 0 | } |
1714 | | |
1715 | 0 | size_t llama_context::state_seq_set_data(llama_seq_id seq_id, const uint8_t * src, size_t size, llama_state_seq_flags flags) { |
1716 | 0 | llama_io_read_buffer io(src, size); |
1717 | 0 | try { |
1718 | 0 | return state_seq_read_data(io, seq_id, flags); |
1719 | 0 | } catch (const std::exception & err) { |
1720 | 0 | LLAMA_LOG_ERROR("%s: error loading state: %s\n", __func__, err.what()); |
1721 | 0 | return 0; |
1722 | 0 | } |
1723 | 0 | } |
1724 | | |
1725 | 0 | bool llama_context::state_load_file(const char * filepath, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) { |
1726 | 0 | llama_file file(filepath, "rb"); |
1727 | | |
1728 | | // sanity checks |
1729 | 0 | { |
1730 | 0 | const uint32_t magic = file.read_u32(); |
1731 | 0 | const uint32_t version = file.read_u32(); |
1732 | |
|
1733 | 0 | if (magic != LLAMA_SESSION_MAGIC || version != LLAMA_SESSION_VERSION) { |
1734 | 0 | LLAMA_LOG_ERROR("%s: unknown (magic, version) for session file: %08x, %08x\n", __func__, magic, version); |
1735 | 0 | return false; |
1736 | 0 | } |
1737 | 0 | } |
1738 | | |
1739 | | // load the prompt |
1740 | 0 | { |
1741 | 0 | const uint32_t n_token_count = file.read_u32(); |
1742 | |
|
1743 | 0 | if (n_token_count > n_token_capacity) { |
1744 | 0 | LLAMA_LOG_ERROR("%s: token count in session file exceeded capacity! %u > %zu\n", __func__, n_token_count, n_token_capacity); |
1745 | 0 | return false; |
1746 | 0 | } |
1747 | | |
1748 | 0 | file.read_raw(tokens_out, sizeof(llama_token) * n_token_count); |
1749 | 0 | *n_token_count_out = n_token_count; |
1750 | 0 | } |
1751 | | |
1752 | | // restore the context state |
1753 | 0 | { |
1754 | 0 | const size_t n_state_size_cur = file.size() - file.tell(); |
1755 | |
|
1756 | 0 | llama_io_read_file io(&file);
1757 | 0 | const size_t n_read = state_read_data(io); |
1758 | |
|
1759 | 0 | if (n_read != n_state_size_cur) { |
1760 | 0 | LLAMA_LOG_ERROR("%s: did not read all of the session file data! size %zu, got %zu\n", __func__, n_state_size_cur, n_read); |
1761 | 0 | return false; |
1762 | 0 | } |
1763 | 0 | } |
1764 | | |
1765 | 0 | return true; |
1766 | 0 | } |
1767 | | |
1768 | 0 | bool llama_context::state_save_file(const char * filepath, const llama_token * tokens, size_t n_token_count) { |
1769 | 0 | llama_file file(filepath, "wb"); |
1770 | |
|
1771 | 0 | file.write_u32(LLAMA_SESSION_MAGIC); |
1772 | 0 | file.write_u32(LLAMA_SESSION_VERSION); |
1773 | | |
1774 | | // save the prompt |
1775 | 0 | file.write_u32((uint32_t) n_token_count); |
1776 | 0 | file.write_raw(tokens, sizeof(llama_token) * n_token_count); |
1777 | | |
1778 | | // save the context state using the stream writer
1779 | 0 | llama_io_write_file io(&file); |
1780 | 0 | state_write_data(io); |
1781 | |
|
1782 | 0 | return true; |
1783 | 0 | } |
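// Resulting session file layout, in write order:
//
//   [u32 LLAMA_SESSION_MAGIC][u32 LLAMA_SESSION_VERSION][u32 n_token_count]
//   [llama_token tokens[n_token_count]][state blob produced by state_write_data]
//
// state_load_file() above consumes the file in exactly the same order.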
1784 | | |
1785 | 0 | size_t llama_context::state_seq_load_file(llama_seq_id seq_id, const char * filepath, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) { |
1786 | 0 | llama_file file(filepath, "rb"); |
1787 | | |
1788 | | // version checks |
1789 | 0 | { |
1790 | 0 | const uint32_t magic = file.read_u32(); |
1791 | 0 | const uint32_t version = file.read_u32(); |
1792 | |
|
1793 | 0 | if (magic != LLAMA_STATE_SEQ_MAGIC || version != LLAMA_STATE_SEQ_VERSION) { |
1794 | 0 | LLAMA_LOG_ERROR("%s: unknown (magic, version) for sequence state file: %08x, %08x\n", __func__, magic, version); |
1795 | 0 | return 0; |
1796 | 0 | } |
1797 | 0 | } |
1798 | | |
1799 | | // load the prompt |
1800 | 0 | { |
1801 | 0 | const uint32_t n_token_count = file.read_u32(); |
1802 | |
|
1803 | 0 | if (n_token_count > n_token_capacity) { |
1804 | 0 | LLAMA_LOG_ERROR("%s: token count in sequence state file exceeded capacity! %u > %zu\n", __func__, n_token_count, n_token_capacity); |
1805 | 0 | return 0; |
1806 | 0 | } |
1807 | | |
1808 | 0 | file.read_raw(tokens_out, sizeof(llama_token) * n_token_count); |
1809 | 0 | *n_token_count_out = n_token_count; |
1810 | 0 | } |
1811 | | |
1812 | | // restore the context state |
1813 | 0 | { |
1814 | 0 | const size_t state_size = file.size() - file.tell(); |
1815 | 0 | llama_io_read_file io(&file); |
1816 | 0 | const size_t nread = state_seq_read_data(io, seq_id, 0); |
1817 | 0 | if (!nread) { |
1818 | 0 | LLAMA_LOG_ERROR("%s: failed to restore sequence state\n", __func__); |
1819 | 0 | return 0; |
1820 | 0 | } |
1821 | 0 | GGML_ASSERT(nread <= state_size); |
1822 | 0 | GGML_ASSERT(nread + sizeof(uint32_t) * 3 + sizeof(llama_token) * *n_token_count_out == file.tell()); |
1823 | 0 | } |
1824 | | |
1825 | 0 | return file.tell(); |
1826 | 0 | } |
1827 | | |
1828 | 0 | size_t llama_context::state_seq_save_file(llama_seq_id seq_id, const char * filepath, const llama_token * tokens, size_t n_token_count) { |
1829 | 0 | llama_file file(filepath, "wb"); |
1830 | |
|
1831 | 0 | file.write_u32(LLAMA_STATE_SEQ_MAGIC); |
1832 | 0 | file.write_u32(LLAMA_STATE_SEQ_VERSION); |
1833 | | |
1834 | | // save the prompt |
1835 | 0 | file.write_u32((uint32_t) n_token_count); |
1836 | 0 | file.write_raw(tokens, sizeof(llama_token) * n_token_count); |
1837 | | |
1838 | | // save the context state using the stream writer
1839 | 0 | llama_io_write_file io(&file); |
1840 | 0 | state_seq_write_data(io, seq_id, 0); |
1841 | |
|
1842 | 0 | const size_t res = file.tell(); |
1843 | 0 | GGML_ASSERT(res == sizeof(uint32_t) * 3 + sizeof(llama_token) * n_token_count + io.n_bytes()); |
1844 | |
|
1845 | 0 | return res; |
1846 | 0 | } |
1847 | | |
1848 | 0 | size_t llama_context::state_write_data(llama_io_write_i & io) { |
1849 | 0 | LLAMA_LOG_DEBUG("%s: writing state\n", __func__); |
1850 | | |
1851 | | // write model info |
1852 | 0 | { |
1853 | 0 | LLAMA_LOG_DEBUG("%s: - writing model info\n", __func__); |
1854 | |
|
1855 | 0 | const std::string arch_str = llm_arch_name(model.arch); |
1856 | 0 | io.write_string(arch_str); |
1857 | | // TODO: add more model-specific info which should prevent loading the session file if not identical |
1858 | 0 | } |
1859 | | |
1860 | | // write output ids |
1861 | 0 | { |
1862 | 0 | LLAMA_LOG_DEBUG("%s: - writing output ids\n", __func__); |
1863 | |
|
1864 | 0 | const auto n_outputs = this->n_outputs; |
1865 | 0 | const auto & output_ids = this->output_ids; |
1866 | |
|
1867 | 0 | std::vector<int32_t> w_output_pos; |
1868 | |
|
1869 | 0 | w_output_pos.resize(n_outputs); |
1870 | | |
1871 | | // build a more compact representation of the output ids |
1872 | 0 | for (size_t i = 0; i < n_batch(); ++i) { |
1873 | | // map an output id to a position in the batch |
1874 | 0 | int64_t pos = output_ids[i]; |
1875 | 0 | if (pos >= 0) { |
1876 | 0 | GGML_ASSERT(pos < n_outputs); |
1877 | 0 | w_output_pos[pos] = i; |
1878 | 0 | } |
1879 | 0 | } |
1880 | |
|
1881 | 0 | io.write(&n_outputs, sizeof(n_outputs)); |
1882 | |
|
1883 | 0 | if (n_outputs) { |
1884 | 0 | io.write(w_output_pos.data(), n_outputs * sizeof(int32_t)); |
1885 | 0 | } |
1886 | 0 | } |
1887 | | |
1888 | | // write logits |
1889 | 0 | { |
1890 | 0 | LLAMA_LOG_DEBUG("%s: - writing logits\n", __func__); |
1891 | |
|
1892 | 0 | const uint64_t logits_size = std::min((uint64_t) this->logits_size, (uint64_t) n_outputs * model.vocab.n_tokens()); |
1893 | |
|
1894 | 0 | io.write(&logits_size, sizeof(logits_size)); |
1895 | |
|
1896 | 0 | if (logits_size) { |
1897 | 0 | io.write(logits, logits_size * sizeof(float)); |
1898 | 0 | } |
1899 | 0 | } |
1900 | | |
1901 | | // write embeddings |
1902 | 0 | { |
1903 | 0 | LLAMA_LOG_DEBUG("%s: - writing embeddings\n", __func__); |
1904 | |
|
1905 | 0 | const uint64_t embd_size = std::min((uint64_t) this->embd_size, (uint64_t) n_outputs * model.hparams.n_embd); |
1906 | |
|
1907 | 0 | io.write(&embd_size, sizeof(embd_size)); |
1908 | |
|
1909 | 0 | if (embd_size) { |
1910 | 0 | io.write(embd, embd_size * sizeof(float)); |
1911 | 0 | } |
1912 | 0 | } |
1913 | |
|
1914 | 0 | if (memory != nullptr) { |
1915 | 0 | LLAMA_LOG_DEBUG("%s: - writing memory module\n", __func__); |
1916 | 0 | memory->state_write(io); |
1917 | 0 | } |
1918 | |
|
1919 | 0 | return io.n_bytes(); |
1920 | 0 | } |
1921 | | |
1922 | 0 | size_t llama_context::state_read_data(llama_io_read_i & io) { |
1923 | 0 | LLAMA_LOG_DEBUG("%s: reading state\n", __func__); |
1924 | | |
1925 | | // read model info |
1926 | 0 | { |
1927 | 0 | LLAMA_LOG_DEBUG("%s: - reading model info\n", __func__); |
1928 | |
|
1929 | 0 | const std::string cur_arch_str = llm_arch_name(model.arch); |
1930 | |
|
1931 | 0 | std::string arch_str; |
1932 | 0 | io.read_string(arch_str); |
1933 | 0 | if (cur_arch_str != arch_str) { |
1934 | 0 | throw std::runtime_error(format("wrong model arch: '%s' instead of '%s'", arch_str.c_str(), cur_arch_str.c_str())); |
1935 | 0 | } |
1936 | | // TODO: add more info which needs to be identical but which is not verified otherwise |
1937 | 0 | } |
1938 | | |
1939 | | // read output ids |
1940 | 0 | { |
1941 | 0 | LLAMA_LOG_DEBUG("%s: - reading output ids\n", __func__); |
1942 | |
|
1943 | 0 | auto n_outputs = this->n_outputs; |
1944 | 0 | io.read_to(&n_outputs, sizeof(n_outputs)); |
1945 | |
|
1946 | 0 | if (n_outputs > output_reserve(n_outputs)) { |
1947 | 0 | throw std::runtime_error("could not reserve outputs"); |
1948 | 0 | } |
1949 | | |
1950 | 0 | std::vector<int32_t> output_pos; |
1951 | |
|
1952 | 0 | if (n_outputs) { |
1953 | 0 | output_pos.resize(n_outputs); |
1954 | 0 | io.read_to(output_pos.data(), n_outputs * sizeof(int32_t)); |
1955 | |
|
1956 | 0 | for (int32_t i = 0; i < (int32_t) output_pos.size(); ++i) { |
1957 | 0 | int32_t id = output_pos[i]; |
1958 | 0 | if ((uint32_t) id >= n_batch()) { |
1959 | 0 | throw std::runtime_error(format("invalid output id, %d does not fit in batch size of %u", id, n_batch())); |
1960 | 0 | } |
1961 | 0 | this->output_ids[id] = i; |
1962 | 0 | } |
1963 | | |
1964 | 0 | this->n_outputs = n_outputs; |
1965 | 0 | } |
1966 | 0 | } |
1967 | | |
1968 | | // read logits |
1969 | 0 | { |
1970 | 0 | LLAMA_LOG_DEBUG("%s: - reading logits\n", __func__); |
1971 | |
|
1972 | 0 | uint64_t logits_size; |
1973 | 0 | io.read_to(&logits_size, sizeof(logits_size)); |
1974 | |
|
1975 | 0 | if (this->logits_size < logits_size) { |
1976 | 0 | throw std::runtime_error("logits buffer too small"); |
1977 | 0 | } |
1978 | | |
1979 | 0 | if (logits_size) { |
1980 | 0 | io.read_to(this->logits, logits_size * sizeof(float)); |
1981 | 0 | } |
1982 | 0 | } |
1983 | | |
1984 | | // read embeddings |
1985 | 0 | { |
1986 | 0 | LLAMA_LOG_DEBUG("%s: - reading embeddings\n", __func__); |
1987 | |
|
1988 | 0 | uint64_t embd_size; |
1989 | 0 | io.read_to(&embd_size, sizeof(embd_size)); |
1990 | |
|
1991 | 0 | if (this->embd_size < embd_size) { |
1992 | 0 | throw std::runtime_error("embeddings buffer too small"); |
1993 | 0 | } |
1994 | | |
1995 | 0 | if (embd_size) { |
1996 | 0 | io.read_to(this->embd, embd_size * sizeof(float)); |
1997 | 0 | } |
1998 | 0 | } |
1999 | | |
2000 | 0 | if (memory) { |
2001 | 0 | LLAMA_LOG_DEBUG("%s: - reading memory module\n", __func__); |
2002 | |
|
2003 | 0 | memory->state_read(io); |
2004 | 0 | } |
2005 | |
|
2006 | 0 | return io.n_bytes(); |
2007 | 0 | } |
2008 | | |
2009 | 0 | size_t llama_context::state_seq_write_data(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) { |
2010 | 0 | GGML_UNUSED(seq_id); |
2011 | |
|
2012 | 0 | if (memory) { |
2013 | 0 | memory->state_write(io, seq_id, flags); |
2014 | 0 | } |
2015 | |
|
2016 | 0 | return io.n_bytes(); |
2017 | 0 | } |
2018 | | |
2019 | 0 | size_t llama_context::state_seq_read_data(llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) { |
2020 | 0 | GGML_UNUSED(seq_id); |
2021 | |
|
2022 | 0 | if (memory) { |
2023 | 0 | memory->state_read(io, seq_id, flags); |
2024 | 0 | } |
2025 | |
|
2026 | 0 | return io.n_bytes(); |
2027 | 0 | } |
2028 | | |
2029 | | // |
2030 | | // perf |
2031 | | // |
2032 | | |
2033 | 0 | llama_perf_context_data llama_context::perf_get_data() const { |
2034 | 0 | llama_perf_context_data data = {}; |
2035 | |
|
2036 | 0 | data.t_start_ms = 1e-3 * t_start_us; |
2037 | 0 | data.t_load_ms = 1e-3 * t_load_us; |
2038 | 0 | data.t_p_eval_ms = 1e-3 * t_p_eval_us; |
2039 | 0 | data.t_eval_ms = 1e-3 * t_eval_us; |
2040 | 0 | data.n_p_eval = std::max(1, n_p_eval); |
2041 | 0 | data.n_eval = std::max(1, n_eval); |
2042 | 0 | data.n_reused = std::max(0, n_reused); |
2043 | |
|
2044 | 0 | return data; |
2045 | 0 | } |
2046 | | |
2047 | 0 | void llama_context::perf_reset() { |
2048 | 0 | t_start_us = ggml_time_us(); |
2049 | 0 | t_eval_us = n_eval = 0; |
2050 | 0 | t_p_eval_us = n_p_eval = 0; |
2051 | 0 | n_reused = 0; |
2052 | 0 | } |
2053 | | |
2054 | 0 | std::map<ggml_backend_buffer_type_t, llama_memory_breakdown_data> llama_context::memory_breakdown() const { |
2055 | 0 | std::map<ggml_backend_buffer_type_t, llama_memory_breakdown_data> ret; |
2056 | 0 | for (const auto & buft_size : model.memory_breakdown()) { |
2057 | 0 | ret[buft_size.first].model += buft_size.second; |
2058 | 0 | } |
2059 | 0 | for (const auto & buft_size : memory->memory_breakdown()) { |
2060 | 0 | ret[buft_size.first].context += buft_size.second; |
2061 | 0 | } |
2062 | 0 | for (const auto & backend_ptr : backends) { |
2063 | 0 | ggml_backend_t backend = backend_ptr.get(); |
2064 | 0 | ret[ggml_backend_sched_get_buffer_type(sched.get(), backend)].compute += ggml_backend_sched_get_buffer_size(sched.get(), backend); |
2065 | 0 | } |
2066 | 0 | return ret; |
2067 | 0 | } |
2068 | | |
2069 | | // |
2070 | | // training |
2071 | | // |
2072 | | |
2073 | 0 | static void llama_set_param(struct ggml_tensor * tensor, llama_opt_param_filter param_filter, void * userdata) { |
2074 | 0 | if (!tensor || tensor->type != GGML_TYPE_F32) { |
2075 | 0 | return; |
2076 | 0 | } |
2077 | 0 | if (!param_filter(tensor, userdata)) { |
2078 | 0 | return; |
2079 | 0 | } |
2080 | 0 | if (strcmp(tensor->name, "token_embd.weight") == 0) { |
2081 | 0 | return; // FIXME |
2082 | 0 | } |
2083 | 0 | if (strcmp(tensor->name, "rope_freqs.weight") == 0) { |
2084 | 0 | return; // FIXME |
2085 | 0 | } |
2086 | 0 | ggml_set_param(tensor); |
2087 | 0 | } |
2088 | | |
2089 | 0 | void llama_context::opt_init(struct llama_model * model, struct llama_opt_params lopt_params) { |
2090 | 0 | GGML_ASSERT(!opt_ctx); |
2091 | 0 | model->hparams.n_ctx_train = lopt_params.n_ctx_train > 0 ? lopt_params.n_ctx_train : n_ctx(); |
2092 | 0 | const uint32_t n_batch = std::min(this->n_batch(), model->hparams.n_ctx_train); |
2093 | 0 | const uint32_t n_ubatch = std::min(this->n_ubatch(), n_batch); |
2094 | 0 | GGML_ASSERT(model->hparams.n_ctx_train % n_batch == 0); |
2095 | 0 | GGML_ASSERT(n_batch % n_ubatch == 0); |
2096 | |
|
2097 | 0 | ggml_opt_params opt_params = ggml_opt_default_params(sched.get(), GGML_OPT_LOSS_TYPE_CROSS_ENTROPY); |
2098 | 0 | opt_params.opt_period = n_batch / n_ubatch; |
2099 | 0 | opt_params.get_opt_pars = lopt_params.get_opt_pars; |
2100 | 0 | opt_params.get_opt_pars_ud = lopt_params.get_opt_pars_ud; |
2101 | 0 | opt_params.optimizer = lopt_params.optimizer_type; |
2102 | 0 | opt_ctx = ggml_opt_init(opt_params); |
2103 | |
|
2104 | 0 | llama_opt_param_filter param_filter = lopt_params.param_filter; |
2105 | 0 | void * param_filter_ud = lopt_params.param_filter_ud; |
2106 | | |
2107 | | //llama_set_param(model->tok_embd, param_filter, param_filter_ud); // FIXME |
2108 | 0 | llama_set_param(model->type_embd, param_filter, param_filter_ud); |
2109 | 0 | llama_set_param(model->pos_embd, param_filter, param_filter_ud); |
2110 | 0 | llama_set_param(model->tok_norm, param_filter, param_filter_ud); |
2111 | 0 | llama_set_param(model->tok_norm_b, param_filter, param_filter_ud); |
2112 | 0 | llama_set_param(model->output_norm, param_filter, param_filter_ud); |
2113 | 0 | llama_set_param(model->output_norm_b, param_filter, param_filter_ud); |
2114 | 0 | llama_set_param(model->output, param_filter, param_filter_ud); |
2115 | 0 | llama_set_param(model->output_b, param_filter, param_filter_ud); |
2116 | 0 | llama_set_param(model->output_norm_enc, param_filter, param_filter_ud); |
2117 | 0 | llama_set_param(model->cls, param_filter, param_filter_ud); |
2118 | 0 | llama_set_param(model->cls_b, param_filter, param_filter_ud); |
2119 | 0 | llama_set_param(model->cls_out, param_filter, param_filter_ud); |
2120 | 0 | llama_set_param(model->cls_out_b, param_filter, param_filter_ud); |
2121 | |
|
2122 | 0 | for (struct llama_layer & layer : model->layers) { |
2123 | 0 | for (size_t i = 0; i < sizeof(layer)/sizeof(struct ggml_tensor *); ++i) { |
2124 | 0 | llama_set_param(reinterpret_cast<struct ggml_tensor **>(&layer)[i], param_filter, param_filter_ud); |
2125 | 0 | } |
2126 | 0 | } |
2127 | 0 | } |
2128 | | |
2129 | | void llama_context::opt_epoch_iter( |
2130 | | ggml_opt_dataset_t dataset, |
2131 | | ggml_opt_result_t result, |
2132 | | const std::vector<llama_token> & tokens, |
2133 | | const std::vector<llama_token> & labels_sparse, |
2134 | | llama_batch & batch, |
2135 | | ggml_opt_epoch_callback callback, |
2136 | | bool train, |
2137 | | int64_t idata_in_loop, |
2138 | | int64_t ndata_in_loop, |
2139 | 0 | int64_t t_loop_start) { |
2140 | 0 | GGML_ASSERT(opt_ctx); |
2141 | 0 | const uint32_t n_ctx = llama_model_n_ctx_train(&model); |
2142 | 0 | const uint32_t n_batch = std::min(this->n_batch(), n_ctx); |
2143 | 0 | const uint32_t n_ubatch = std::min(this->n_ubatch(), n_batch); |
2144 | |
|
2145 | 0 | memory->clear(true); |
2146 | |
|
2147 | 0 | for (uint32_t pos_ctx = 0; pos_ctx < n_ctx; pos_ctx += n_batch) { |
2148 | 0 | batch.n_tokens = n_batch; |
2149 | 0 | for (uint32_t pos_batch = 0; pos_batch < n_batch; ++pos_batch) { |
2150 | 0 | batch.token [pos_batch] = tokens[pos_ctx + pos_batch]; |
2151 | 0 | batch.pos [pos_batch] = pos_ctx + pos_batch; |
2152 | 0 | batch.n_seq_id[pos_batch] = 1; |
2153 | 0 | batch.seq_id [pos_batch][0] = 0; |
2154 | 0 | batch.logits [pos_batch] = true; |
2155 | 0 | } |
2156 | |
|
2157 | 0 | if (!balloc->init(batch, model.vocab, nullptr, model.hparams.n_embd_inp(), cparams.kv_unified ? LLAMA_MAX_SEQ : cparams.n_seq_max, true)) { |
2158 | 0 | LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__); |
2159 | 0 | return; |
2160 | 0 | } |
2161 | | |
2162 | 0 | const uint32_t n_tokens_all = balloc->get_n_tokens(); |
2163 | |
|
2164 | 0 | n_queued_tokens += n_tokens_all; |
2165 | |
|
2166 | 0 | embd_seq.clear(); |
2167 | |
|
2168 | 0 | uint32_t n_outputs_all = n_tokens_all; |
2169 | |
|
2170 | 0 | auto mctx = memory->init_batch(*balloc, cparams.n_ubatch, true); |
2171 | 0 | if (!mctx || mctx->get_status() != LLAMA_MEMORY_STATUS_SUCCESS) { |
2172 | 0 | LLAMA_LOG_ERROR("%s: could not initialize batch\n", __func__); |
2173 | 0 | break; |
2174 | 0 | } |
2175 | | |
2176 | | // reserve output buffer |
2177 | 0 | if (output_reserve(n_outputs_all) < n_outputs_all) { |
2178 | 0 | LLAMA_LOG_ERROR("%s: could not reserve space for batch with %d outputs\n", __func__, n_outputs_all); |
2179 | 0 | GGML_ABORT("TODO: handle this error"); |
2180 | 0 | }
2181 | |
|
2182 | 0 | uint32_t pos_batch = 0; |
2183 | 0 | do { |
2184 | 0 | const auto & ubatch = mctx->get_ubatch(); |
2185 | |
|
2186 | 0 | n_outputs = ubatch.n_tokens; |
2187 | |
|
2188 | 0 | if (!mctx->apply()) { |
2189 | 0 | LLAMA_LOG_ERROR("%s: failed to update the memory context\n", __func__); |
2190 | 0 | break; |
2191 | 0 | } |
2192 | | |
2193 | 0 | auto * res = gf_res_prev.get(); |
2194 | |
|
2195 | 0 | const auto gparams = graph_params(res, ubatch, mctx.get(), LLM_GRAPH_TYPE_DEFAULT); |
2196 | |
|
2197 | 0 | res->reset(); |
2198 | |
|
2199 | 0 | auto * gf = model.build_graph(gparams); |
2200 | |
|
2201 | 0 | struct ggml_context * ctx_compute_opt; |
2202 | 0 | { |
2203 | 0 | const size_t size_gf = ggml_graph_size(gf); |
2204 | 0 | const size_t size_meta = 4*size_gf*ggml_tensor_overhead() + 2*ggml_graph_overhead_custom(size_gf, /*grads = */ true); |
2205 | 0 | struct ggml_init_params params = { |
2206 | 0 | /*.mem_size =*/ size_meta, |
2207 | 0 | /*.mem_buffer =*/ nullptr, |
2208 | 0 | /*.no_alloc =*/ true, |
2209 | 0 | }; |
2210 | 0 | ctx_compute_opt = ggml_init(params); |
2211 | 0 | } |
2212 | 0 | ggml_opt_prepare_alloc(opt_ctx, ctx_compute_opt, gf, res->get_tokens(), res->get_logits()); |
2213 | 0 | ggml_opt_alloc(opt_ctx, train); |
2214 | |
|
2215 | 0 | res->set_inputs(&ubatch); |
2216 | 0 | { |
2217 | 0 | struct ggml_tensor * labels = ggml_opt_labels(opt_ctx); |
2218 | 0 | GGML_ASSERT(labels->ne[1] == n_ubatch); |
2219 | 0 | ggml_set_zero(labels); |
2220 | 0 | const float onef = 1.0f; |
2221 | 0 | for (uint32_t pos_ubatch = 0; pos_ubatch < n_ubatch; ++pos_ubatch) { |
2222 | 0 | const uint32_t ilabel = pos_ctx + pos_batch + pos_ubatch; |
2223 | 0 | GGML_ASSERT(labels_sparse[ilabel] < labels->ne[0]); |
2224 | 0 | ggml_backend_tensor_set(labels, &onef, (pos_ubatch*labels->ne[0] + labels_sparse[ilabel])*sizeof(float), sizeof(float)); |
2225 | 0 | } |
2226 | 0 | } |
2227 | 0 | ggml_opt_eval(opt_ctx, result); |
2228 | 0 | if (callback) { |
2229 | 0 | callback(train, opt_ctx, dataset, result, idata_in_loop + (pos_ctx + pos_batch)/n_ubatch + 1, ndata_in_loop, t_loop_start); |
2230 | 0 | } |
2231 | 0 | ggml_free(ctx_compute_opt); |
2232 | |
|
2233 | 0 | pos_batch += ubatch.n_tokens; |
2234 | 0 | } while (mctx->next()); |
2235 | 0 | } |
2236 | 0 | } |
2237 | | |
2238 | | void llama_context::opt_epoch( |
2239 | | ggml_opt_dataset_t dataset, |
2240 | | ggml_opt_result_t result_train, |
2241 | | ggml_opt_result_t result_eval, |
2242 | | int64_t idata_split, |
2243 | | ggml_opt_epoch_callback callback_train, |
2244 | 0 | ggml_opt_epoch_callback callback_eval) { |
2245 | 0 | const uint32_t n_ctx = this->n_ctx(); |
2246 | 0 | const uint32_t n_batch = std::min(cparams.n_batch, n_ctx); |
2247 | 0 | const uint32_t n_ubatch = std::min(cparams.n_ubatch, n_batch); |
2248 | 0 | const int64_t ndata = ggml_opt_dataset_ndata(dataset); |
2249 | |
|
2250 | 0 | GGML_ASSERT(idata_split >= 0); |
2251 | 0 | GGML_ASSERT(idata_split <= ndata); |
2252 | |
|
2253 | 0 | const uint32_t ubatch_per_ctx = n_ctx / n_ubatch; |
2254 | |
|
2255 | 0 | struct llama_batch batch = llama_batch_init(n_batch, 0, 1); |
2256 | 0 | std::vector<llama_token> tokens(n_ctx); |
2257 | 0 | std::vector<llama_token> labels_sparse(n_ctx); |
2258 | |
|
2259 | 0 | int64_t idata = 0; |
2260 | |
|
2261 | 0 | int64_t t_loop_start = ggml_time_us(); |
2262 | 0 | int64_t ndata_in_loop = idata_split*ubatch_per_ctx; |
2263 | 0 | for (; idata < idata_split; ++idata) { |
2264 | 0 | constexpr bool train = true; |
2265 | 0 | const int64_t idata_in_loop = idata*ubatch_per_ctx; |
2266 | |
|
2267 | 0 | ggml_opt_dataset_get_batch_host(dataset, tokens.data(), n_ctx*sizeof(llama_token), labels_sparse.data(), idata); |
2268 | 0 | opt_epoch_iter(dataset, result_train, tokens, labels_sparse, batch, |
2269 | 0 | callback_train, train, idata_in_loop, ndata_in_loop, t_loop_start); |
2270 | 0 | } |
2271 | |
|
2272 | 0 | t_loop_start = ggml_time_us(); |
2273 | 0 | ndata_in_loop = (ndata - idata_split)*ubatch_per_ctx; |
2274 | 0 | for (; idata < ndata; ++idata) { |
2275 | 0 | constexpr bool train = false; |
2276 | 0 | const int64_t idata_in_loop = (idata - idata_split)*ubatch_per_ctx; |
2277 | |
|
2278 | 0 | ggml_opt_dataset_get_batch_host(dataset, tokens.data(), n_ctx*sizeof(llama_token), labels_sparse.data(), idata); |
2279 | 0 | opt_epoch_iter(dataset, result_eval, tokens, labels_sparse, batch, |
2280 | 0 | callback_eval, train, idata_in_loop, ndata_in_loop, t_loop_start); |
2281 | 0 | } |
2282 | |
|
2283 | 0 | llama_batch_free(batch); |
2284 | 0 | } |
2285 | | |
2286 | | // |
2287 | | // interface implementation |
2288 | | // |
2289 | | |
2290 | 0 | llama_context_params llama_context_default_params() { |
2291 | 0 | llama_context_params result = { |
2292 | 0 | /*.n_ctx =*/ 512, |
2293 | 0 | /*.n_batch =*/ 2048, |
2294 | 0 | /*.n_ubatch =*/ 512, |
2295 | 0 | /*.n_seq_max =*/ 1, |
2296 | 0 | /*.n_threads =*/ GGML_DEFAULT_N_THREADS, // TODO: better default |
2297 | 0 | /*.n_threads_batch =*/ GGML_DEFAULT_N_THREADS, |
2298 | 0 | /*.rope_scaling_type =*/ LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED, |
2299 | 0 | /*.pooling_type =*/ LLAMA_POOLING_TYPE_UNSPECIFIED, |
2300 | 0 | /*.attention_type =*/ LLAMA_ATTENTION_TYPE_UNSPECIFIED, |
2301 | 0 | /*.flash_attn_type =*/ LLAMA_FLASH_ATTN_TYPE_AUTO, |
2302 | 0 | /*.rope_freq_base =*/ 0.0f, |
2303 | 0 | /*.rope_freq_scale =*/ 0.0f, |
2304 | 0 | /*.yarn_ext_factor =*/ -1.0f, |
2305 | 0 | /*.yarn_attn_factor =*/ -1.0f, |
2306 | 0 | /*.yarn_beta_fast =*/ -1.0f, |
2307 | 0 | /*.yarn_beta_slow =*/ -1.0f, |
2308 | 0 | /*.yarn_orig_ctx =*/ 0, |
2309 | 0 | /*.defrag_thold =*/ -1.0f, |
2310 | 0 | /*.cb_eval =*/ nullptr, |
2311 | 0 | /*.cb_eval_user_data =*/ nullptr, |
2312 | 0 | /*.type_k =*/ GGML_TYPE_F16, |
2313 | 0 | /*.type_v =*/ GGML_TYPE_F16, |
2314 | 0 | /*.abort_callback =*/ nullptr, |
2315 | 0 | /*.abort_callback_data =*/ nullptr, |
2316 | 0 | /*.embeddings =*/ false, |
2317 | 0 | /*.offload_kqv =*/ true, |
2318 | 0 | /*.no_perf =*/ true, |
2319 | 0 | /*.op_offload =*/ true, |
2320 | 0 | /*.swa_full =*/ true, |
2321 | 0 | /*.kv_unified =*/ false, |
2322 | 0 | }; |
2323 | |
|
2324 | 0 | return result; |
2325 | 0 | } |
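// A minimal usage sketch for these defaults; the model-loading calls
// (llama_model_default_params, llama_model_load_from_file, llama_model_free) are assumed
// from llama.h:
//
//     llama_model * model = llama_model_load_from_file("model.gguf", llama_model_default_params());
//
//     llama_context_params cparams = llama_context_default_params();
//     cparams.n_ctx   = 4096; // override the 512-token default
//     cparams.n_batch = 1024;
//
//     llama_context * ctx = llama_init_from_model(model, cparams);
//     if (ctx == nullptr) {
//         // initialization failed - the log contains the reason
//     }
//
//     // ... use the context ...
//
//     llama_free(ctx);
//     llama_model_free(model);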
2326 | | |
2327 | | llama_context * llama_init_from_model( |
2328 | | llama_model * model, |
2329 | 0 | llama_context_params params) { |
2330 | 0 | if (!model) { |
2331 | 0 | LLAMA_LOG_ERROR("%s: model cannot be NULL\n", __func__); |
2332 | 0 | return nullptr; |
2333 | 0 | } |
2334 | | |
2335 | 0 | if (params.n_batch == 0 && params.n_ubatch == 0) { |
2336 | 0 | LLAMA_LOG_ERROR("%s: n_batch and n_ubatch cannot both be zero\n", __func__); |
2337 | 0 | return nullptr; |
2338 | 0 | } |
2339 | | |
2340 | 0 | if (params.n_ctx == 0 && model->hparams.n_ctx_train == 0) { |
2341 | 0 | LLAMA_LOG_ERROR("%s: n_ctx and model->hparams.n_ctx_train cannot both be zero\n", __func__); |
2342 | 0 | return nullptr; |
2343 | 0 | } |
2344 | | |
2345 | 0 | if (params.flash_attn_type != LLAMA_FLASH_ATTN_TYPE_DISABLED && model->arch == LLM_ARCH_GROK) { |
2346 | 0 | LLAMA_LOG_WARN("%s: flash_attn is not compatible with Grok - forcing off\n", __func__); |
2347 | 0 | params.flash_attn_type = LLAMA_FLASH_ATTN_TYPE_DISABLED; |
2348 | 0 | } |
2349 | |
|
2350 | 0 | if (params.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_AUTO && ggml_is_quantized(params.type_k)) { |
2351 | 0 | const uint32_t blck_size = ggml_blck_size(params.type_k); |
2352 | 0 | if (model->hparams.n_embd_head_k % blck_size != 0) { |
2353 | 0 | LLAMA_LOG_ERROR("%s: K cache type %s with block size %u does not divide n_embd_head_k=%u\n", |
2354 | 0 | __func__, ggml_type_name(params.type_k), blck_size, model->hparams.n_embd_head_k); |
2355 | 0 | return nullptr; |
2356 | 0 | } |
2357 | 0 | } |
2358 | | |
2359 | 0 | if (params.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_AUTO && ggml_is_quantized(params.type_v)) { |
2360 | 0 | const uint32_t blck_size = ggml_blck_size(params.type_v); |
2361 | 0 | if (model->hparams.n_embd_head_v % blck_size != 0) { |
2362 | 0 | LLAMA_LOG_ERROR("%s: V cache type %s with block size %u does not divide n_embd_head_v=%u\n",
2363 | 0 | __func__, ggml_type_name(params.type_v), blck_size, model->hparams.n_embd_head_v); |
2364 | 0 | return nullptr; |
2365 | 0 | } |
2366 | 0 | } |
2367 | | |
2368 | 0 | if (ggml_is_quantized(params.type_v) && params.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_DISABLED) { |
2369 | 0 | LLAMA_LOG_ERROR("%s: V cache quantization requires flash_attn\n", __func__); |
2370 | 0 | return nullptr; |
2371 | 0 | } |
2372 | | |
2373 | 0 | if (params.pooling_type != LLAMA_POOLING_TYPE_UNSPECIFIED && |
2374 | 0 | params.pooling_type != model->hparams.pooling_type) { |
2375 | | // user-specified pooling type is different from the model default
2376 | 0 | LLAMA_LOG_WARN("%s: model default pooling_type is [%d], but [%d] was specified\n", __func__, |
2377 | 0 | model->hparams.pooling_type, params.pooling_type); |
2378 | 0 | } |
2379 | |
|
2380 | 0 | try { |
2381 | 0 | auto * ctx = new llama_context(*model, params); |
2382 | 0 | return ctx; |
2383 | 0 | } catch (const std::exception & err) { |
2384 | 0 | LLAMA_LOG_ERROR("%s: failed to initialize the context: %s\n", __func__, err.what()); |
2385 | 0 | } |
2386 | | |
2387 | 0 | return nullptr; |
2388 | 0 | } |
2389 | | |
2390 | | // deprecated |
2391 | | llama_context * llama_new_context_with_model( |
2392 | | llama_model * model, |
2393 | 0 | llama_context_params params) { |
2394 | 0 | return llama_init_from_model(model, params); |
2395 | 0 | } |
2396 | | |
2397 | 0 | void llama_free(llama_context * ctx) { |
2398 | 0 | delete ctx; |
2399 | 0 | } |
2400 | | |
2401 | 0 | uint32_t llama_n_ctx(const llama_context * ctx) { |
2402 | 0 | return ctx->n_ctx(); |
2403 | 0 | } |
2404 | | |
2405 | 0 | uint32_t llama_n_ctx_seq(const llama_context * ctx) { |
2406 | 0 | return ctx->n_ctx_seq(); |
2407 | 0 | } |
2408 | | |
2409 | 0 | uint32_t llama_n_batch(const llama_context * ctx) { |
2410 | 0 | return ctx->n_batch(); |
2411 | 0 | } |
2412 | | |
2413 | 0 | uint32_t llama_n_ubatch(const llama_context * ctx) { |
2414 | 0 | return ctx->n_ubatch(); |
2415 | 0 | } |
2416 | | |
2417 | 0 | uint32_t llama_n_seq_max(const llama_context * ctx) { |
2418 | 0 | return ctx->n_seq_max(); |
2419 | 0 | } |
2420 | | |
2421 | 0 | const llama_model * llama_get_model(const llama_context * ctx) { |
2422 | 0 | return &ctx->get_model(); |
2423 | 0 | } |
2424 | | |
2425 | 0 | enum llama_pooling_type llama_pooling_type(const llama_context * ctx) { |
2426 | 0 | return ctx->pooling_type(); |
2427 | 0 | } |
2428 | | |
2429 | | void llama_attach_threadpool( |
2430 | | llama_context * ctx, |
2431 | | ggml_threadpool_t threadpool, |
2432 | 0 | ggml_threadpool_t threadpool_batch) { |
2433 | 0 | ctx->attach_threadpool(threadpool, threadpool_batch); |
2434 | 0 | } |
2435 | | |
2436 | 0 | void llama_detach_threadpool(llama_context * ctx) { |
2437 | 0 | ctx->detach_threadpool(); |
2438 | 0 | } |
2439 | | |
2440 | 0 | void llama_set_n_threads(llama_context * ctx, int32_t n_threads, int32_t n_threads_batch) { |
2441 | 0 | ctx->set_n_threads(n_threads, n_threads_batch); |
2442 | 0 | } |
2443 | | |
2444 | 0 | int32_t llama_n_threads(llama_context * ctx) { |
2445 | 0 | return ctx->n_threads(); |
2446 | 0 | } |
2447 | | |
2448 | 0 | int32_t llama_n_threads_batch(llama_context * ctx) { |
2449 | 0 | return ctx->n_threads_batch(); |
2450 | 0 | } |
2451 | | |
2452 | 0 | void llama_set_abort_callback(llama_context * ctx, bool (*abort_callback)(void * data), void * abort_callback_data) { |
2453 | 0 | ctx->set_abort_callback(abort_callback, abort_callback_data); |
2454 | 0 | } |
2455 | | |
2456 | 0 | void llama_set_embeddings(llama_context * ctx, bool embeddings) { |
2457 | 0 | ctx->set_embeddings(embeddings); |
2458 | 0 | } |
2459 | | |
2460 | 0 | void llama_set_causal_attn(llama_context * ctx, bool causal_attn) { |
2461 | 0 | ctx->set_causal_attn(causal_attn); |
2462 | 0 | } |
2463 | | |
2464 | 0 | void llama_set_warmup(llama_context * ctx, bool warmup) { |
2465 | 0 | ctx->set_warmup(warmup); |
2466 | 0 | } |
2467 | | |
2468 | 0 | void llama_synchronize(llama_context * ctx) { |
2469 | 0 | ctx->synchronize(); |
2470 | 0 | } |
2471 | | |
2472 | 0 | float * llama_get_logits(llama_context * ctx) { |
2473 | 0 | ctx->synchronize(); |
2474 | |
|
2475 | 0 | return ctx->get_logits(); |
2476 | 0 | } |
2477 | | |
2478 | 0 | float * llama_get_logits_ith(llama_context * ctx, int32_t i) { |
2479 | 0 | ctx->synchronize(); |
2480 | |
|
2481 | 0 | return ctx->get_logits_ith(i); |
2482 | 0 | } |
2483 | | |
2484 | 0 | float * llama_get_embeddings(llama_context * ctx) { |
2485 | 0 | ctx->synchronize(); |
2486 | |
|
2487 | 0 | return ctx->get_embeddings(); |
2488 | 0 | } |
2489 | | |
2490 | 0 | float * llama_get_embeddings_ith(llama_context * ctx, int32_t i) { |
2491 | 0 | ctx->synchronize(); |
2492 | |
|
2493 | 0 | return ctx->get_embeddings_ith(i); |
2494 | 0 | } |
2495 | | |
2496 | 0 | float * llama_get_embeddings_seq(llama_context * ctx, llama_seq_id seq_id) { |
2497 | 0 | ctx->synchronize(); |
2498 | |
|
2499 | 0 | return ctx->get_embeddings_seq(seq_id); |
2500 | 0 | } |
2501 | | |
2502 | | // llama adapter API |
2503 | | |
2504 | | int32_t llama_set_adapter_lora( |
2505 | | llama_context * ctx, |
2506 | | llama_adapter_lora * adapter, |
2507 | 0 | float scale) { |
2508 | 0 | ctx->set_adapter_lora(adapter, scale); |
2509 | |
|
2510 | 0 | return 0; |
2511 | 0 | } |
2512 | | |
2513 | | int32_t llama_rm_adapter_lora( |
2514 | | llama_context * ctx, |
2515 | 0 | llama_adapter_lora * adapter) { |
2516 | 0 | bool res = ctx->rm_adapter_lora(adapter); |
2517 | |
|
2518 | 0 | return res ? 0 : -1; |
2519 | 0 | } |
2520 | | |
2521 | 0 | void llama_clear_adapter_lora(llama_context * ctx) { |
2522 | 0 | ctx->clear_adapter_lora(); |
2523 | 0 | } |
2524 | | |
2525 | | int32_t llama_apply_adapter_cvec( |
2526 | | llama_context * ctx, |
2527 | | const float * data, |
2528 | | size_t len, |
2529 | | int32_t n_embd, |
2530 | | int32_t il_start, |
2531 | 0 | int32_t il_end) { |
2532 | 0 | bool res = ctx->apply_adapter_cvec(data, len, n_embd, il_start, il_end); |
2533 | |
|
2534 | 0 | return res ? 0 : -1; |
2535 | 0 | } |
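// A short usage sketch for the adapter API above; llama_adapter_lora_init is assumed
// from llama.h:
//
//     llama_adapter_lora * lora = llama_adapter_lora_init(model, "adapter.gguf");
//
//     llama_set_adapter_lora(ctx, lora, 0.75f);   // apply the adapter with scale 0.75
//     // ... decode with the adapter active ...
//     llama_rm_adapter_lora(ctx, lora);           // or llama_clear_adapter_lora(ctx) to drop all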
2536 | | |
2537 | | // |
2538 | | // memory |
2539 | | // |
2540 | | |
2541 | 0 | llama_memory_t llama_get_memory(const struct llama_context * ctx) { |
2542 | 0 | return ctx->get_memory(); |
2543 | 0 | } |
2544 | | |
2545 | 0 | void llama_memory_clear(llama_memory_t mem, bool data) { |
2546 | 0 | if (!mem) { |
2547 | 0 | return; |
2548 | 0 | } |
2549 | | |
2550 | 0 | mem->clear(data); |
2551 | 0 | } |
2552 | | |
2553 | | bool llama_memory_seq_rm( |
2554 | | llama_memory_t mem, |
2555 | | llama_seq_id seq_id, |
2556 | | llama_pos p0, |
2557 | 0 | llama_pos p1) { |
2558 | 0 | if (!mem) { |
2559 | 0 | return true; |
2560 | 0 | } |
2561 | | |
2562 | 0 | return mem->seq_rm(seq_id, p0, p1); |
2563 | 0 | } |
2564 | | |
2565 | | void llama_memory_seq_cp( |
2566 | | llama_memory_t mem, |
2567 | | llama_seq_id seq_id_src, |
2568 | | llama_seq_id seq_id_dst, |
2569 | | llama_pos p0, |
2570 | 0 | llama_pos p1) { |
2571 | 0 | if (!mem) { |
2572 | 0 | return; |
2573 | 0 | } |
2574 | | |
2575 | 0 | mem->seq_cp(seq_id_src, seq_id_dst, p0, p1); |
2576 | 0 | } |
2577 | | |
2578 | | void llama_memory_seq_keep( |
2579 | | llama_memory_t mem, |
2580 | 0 | llama_seq_id seq_id) { |
2581 | 0 | if (!mem) { |
2582 | 0 | return; |
2583 | 0 | } |
2584 | | |
2585 | 0 | mem->seq_keep(seq_id); |
2586 | 0 | } |
2587 | | |
2588 | | void llama_memory_seq_add( |
2589 | | llama_memory_t mem, |
2590 | | llama_seq_id seq_id, |
2591 | | llama_pos p0, |
2592 | | llama_pos p1, |
2593 | 0 | llama_pos delta) { |
2594 | 0 | if (!mem) { |
2595 | 0 | return; |
2596 | 0 | } |
2597 | | |
2598 | 0 | mem->seq_add(seq_id, p0, p1, delta); |
2599 | 0 | } |
2600 | | |
2601 | | void llama_memory_seq_div( |
2602 | | llama_memory_t mem, |
2603 | | llama_seq_id seq_id, |
2604 | | llama_pos p0, |
2605 | | llama_pos p1, |
2606 | 0 | int d) { |
2607 | 0 | if (!mem) { |
2608 | 0 | return; |
2609 | 0 | } |
2610 | | |
2611 | 0 | mem->seq_div(seq_id, p0, p1, d); |
2612 | 0 | } |
2613 | | |
2614 | | llama_pos llama_memory_seq_pos_min( |
2615 | | llama_memory_t mem, |
2616 | 0 | llama_seq_id seq_id) { |
2617 | 0 | if (!mem) { |
2618 | 0 | return -1; |
2619 | 0 | } |
2620 | | |
2621 | 0 | return mem->seq_pos_min(seq_id); |
2622 | 0 | } |
2623 | | |
2624 | | llama_pos llama_memory_seq_pos_max( |
2625 | | llama_memory_t mem, |
2626 | 0 | llama_seq_id seq_id) { |
2627 | 0 | if (!mem) { |
2628 | 0 | return -1; |
2629 | 0 | } |
2630 | | |
2631 | 0 | return mem->seq_pos_max(seq_id); |
2632 | 0 | } |
2633 | | |
2634 | 0 | bool llama_memory_can_shift(llama_memory_t mem) { |
2635 | 0 | if (!mem) { |
2636 | 0 | return false; |
2637 | 0 | } |
2638 | | |
2639 | 0 | return mem->get_can_shift(); |
2640 | 0 | } |
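// A short usage sketch for the memory API above, assuming a valid llama_context * ctx:
//
//     llama_memory_t mem = llama_get_memory(ctx);
//
//     // drop everything from position 100 onwards for sequence 0
//     llama_memory_seq_rm(mem, 0, 100, -1);
//
//     // shift the remaining positions of sequence 0 back by 10, if the memory supports it
//     if (llama_memory_can_shift(mem)) {
//         llama_memory_seq_add(mem, 0, 0, 100, -10);
//     }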
2641 | | |
2642 | | // llama state API |
2643 | | |
2644 | | // deprecated |
2645 | 0 | size_t llama_get_state_size(llama_context * ctx) { |
2646 | 0 | return llama_state_get_size(ctx); |
2647 | 0 | } |
2648 | | |
2649 | | // deprecated |
2650 | 0 | size_t llama_copy_state_data(llama_context * ctx, uint8_t * dst) { |
2651 | 0 | return llama_state_get_data(ctx, dst, -1); |
2652 | 0 | } |
2653 | | |
2654 | | // deprecated |
2655 | 0 | size_t llama_set_state_data(llama_context * ctx, const uint8_t * src) { |
2656 | 0 | return llama_state_set_data(ctx, src, -1); |
2657 | 0 | } |
2658 | | |
2659 | | // deprecated |
2660 | 0 | bool llama_load_session_file(llama_context * ctx, const char * path_session, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) { |
2661 | 0 | return llama_state_load_file(ctx, path_session, tokens_out, n_token_capacity, n_token_count_out); |
2662 | 0 | } |
2663 | | |
2664 | | // deprecated |
2665 | 0 | bool llama_save_session_file(llama_context * ctx, const char * path_session, const llama_token * tokens, size_t n_token_count) { |
2666 | 0 | return llama_state_save_file(ctx, path_session, tokens, n_token_count); |
2667 | 0 | } |
2668 | | |
2669 | | // Returns the *actual* size of the state. |
2670 | | // Intended to be used when saving the state to a buffer.
2671 | 0 | size_t llama_state_get_size(llama_context * ctx) { |
2672 | 0 | return ctx->state_get_size(); |
2673 | 0 | } |
2674 | | |
2675 | 0 | size_t llama_state_get_data(llama_context * ctx, uint8_t * dst, size_t size) { |
2676 | 0 | ctx->synchronize(); |
2677 | |
|
2678 | 0 | return ctx->state_get_data(dst, size); |
2679 | 0 | } |
2680 | | |
2681 | | // Sets the state reading from the specified source address |
2682 | 0 | size_t llama_state_set_data(llama_context * ctx, const uint8_t * src, size_t size) { |
2683 | 0 | ctx->synchronize(); |
2684 | |
|
2685 | 0 | return ctx->state_set_data(src, size); |
2686 | 0 | } |
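// Typical two-step usage of the buffer-based state API above:
//
//     const size_t n_state = llama_state_get_size(ctx);
//
//     std::vector<uint8_t> state(n_state);
//     const size_t n_copied = llama_state_get_data(ctx, state.data(), state.size());
//     // n_copied == 0 indicates failure, otherwise n_copied <= n_state
//
//     // ... later, restore into a compatible context (same model architecture):
//     llama_state_set_data(ctx, state.data(), state.size());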
2687 | | |
2688 | 0 | bool llama_state_load_file(llama_context * ctx, const char * path_session, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) { |
2689 | 0 | ctx->synchronize(); |
2690 | |
|
2691 | 0 | try { |
2692 | 0 | return ctx->state_load_file(path_session, tokens_out, n_token_capacity, n_token_count_out); |
2693 | 0 | } catch (const std::exception & err) { |
2694 | 0 | LLAMA_LOG_ERROR("%s: error loading session file: %s\n", __func__, err.what()); |
2695 | 0 | return false; |
2696 | 0 | } |
2697 | 0 | } |
2698 | | |
2699 | 0 | bool llama_state_save_file(llama_context * ctx, const char * path_session, const llama_token * tokens, size_t n_token_count) { |
2700 | 0 | ctx->synchronize(); |
2701 | |
|
2702 | 0 | try { |
2703 | 0 | return ctx->state_save_file(path_session, tokens, n_token_count); |
2704 | 0 | } catch (const std::exception & err) { |
2705 | 0 | LLAMA_LOG_ERROR("%s: error saving session file: %s\n", __func__, err.what()); |
2706 | 0 | return false; |
2707 | 0 | } |
2708 | 0 | } |
2709 | | |
2710 | 0 | size_t llama_state_seq_get_size(llama_context * ctx, llama_seq_id seq_id) { |
2711 | 0 | return llama_state_seq_get_size_ext(ctx, seq_id, 0); |
2712 | 0 | } |
2713 | | |
2714 | 0 | size_t llama_state_seq_get_data(llama_context * ctx, uint8_t * dst, size_t size, llama_seq_id seq_id) { |
2715 | 0 | return llama_state_seq_get_data_ext(ctx, dst, size, seq_id, 0); |
2716 | 0 | } |
2717 | | |
2718 | 0 | size_t llama_state_seq_set_data(llama_context * ctx, const uint8_t * src, size_t size, llama_seq_id seq_id) { |
2719 | 0 | return llama_state_seq_set_data_ext(ctx, src, size, seq_id, 0); |
2720 | 0 | } |
2721 | | |
2722 | 0 | size_t llama_state_seq_get_size_ext(llama_context * ctx, llama_seq_id seq_id, llama_state_seq_flags flags) { |
2723 | 0 | return ctx->state_seq_get_size(seq_id, flags); |
2724 | 0 | } |
2725 | | |
2726 | 0 | size_t llama_state_seq_get_data_ext(llama_context * ctx, uint8_t * dst, size_t size, llama_seq_id seq_id, llama_state_seq_flags flags) { |
2727 | 0 | ctx->synchronize(); |
2728 | |
|
2729 | 0 | return ctx->state_seq_get_data(seq_id, dst, size, flags); |
2730 | 0 | } |
2731 | | |
2732 | 0 | size_t llama_state_seq_set_data_ext(llama_context * ctx, const uint8_t * src, size_t size, llama_seq_id seq_id, llama_state_seq_flags flags) { |
2733 | 0 | ctx->synchronize(); |
2734 | |
|
2735 | 0 | return ctx->state_seq_set_data(seq_id, src, size, flags); |
2736 | 0 | } |
2737 | | |
2738 | 0 | size_t llama_state_seq_save_file(llama_context * ctx, const char * filepath, llama_seq_id seq_id, const llama_token * tokens, size_t n_token_count) { |
2739 | 0 | ctx->synchronize(); |
2740 | |
|
2741 | 0 | try { |
2742 | 0 | return ctx->state_seq_save_file(seq_id, filepath, tokens, n_token_count); |
2743 | 0 | } catch (const std::exception & err) { |
2744 | 0 | LLAMA_LOG_ERROR("%s: error saving sequence state file: %s\n", __func__, err.what()); |
2745 | 0 | return 0; |
2746 | 0 | } |
2747 | 0 | } |
2748 | | |
2749 | 0 | size_t llama_state_seq_load_file(llama_context * ctx, const char * filepath, llama_seq_id dest_seq_id, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) { |
2750 | 0 | ctx->synchronize(); |
2751 | |
|
2752 | 0 | try { |
2753 | 0 | return ctx->state_seq_load_file(dest_seq_id, filepath, tokens_out, n_token_capacity, n_token_count_out); |
2754 | 0 | } catch (const std::exception & err) { |
2755 | 0 | LLAMA_LOG_ERROR("%s: error loading sequence state file: %s\n", __func__, err.what()); |
2756 | 0 | return 0; |
2757 | 0 | } |
2758 | 0 | } |
2759 | | |
2760 | | /// |
2761 | | |
2762 | | int32_t llama_encode( |
2763 | | llama_context * ctx, |
2764 | 0 | llama_batch batch) { |
2765 | 0 | const int ret = ctx->encode(batch); |
2766 | 0 | if (ret != 0) { |
2767 | 0 | LLAMA_LOG_ERROR("%s: failed to encode, ret = %d\n", __func__, ret); |
2768 | 0 | } |
2769 | |
|
2770 | 0 | return ret; |
2771 | 0 | } |
2772 | | |
2773 | | int32_t llama_decode( |
2774 | | llama_context * ctx, |
2775 | 0 | llama_batch batch) { |
2776 | 0 | const int ret = ctx->decode(batch); |
2777 | 0 | if (ret != 0 && ret != 1) { |
2778 | 0 | LLAMA_LOG_ERROR("%s: failed to decode, ret = %d\n", __func__, ret); |
2779 | 0 | } |
2780 | |
|
2781 | 0 | return ret; |
2782 | 0 | } |
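// A minimal decode sketch, assuming llama_batch_get_one from llama.h and an already
// tokenized prompt in `tokens`:
//
//     llama_batch batch = llama_batch_get_one(tokens.data(), (int32_t) tokens.size());
//
//     const int32_t ret = llama_decode(ctx, batch);
//     if (ret == 1) {
//         // could not find a memory slot - retry with a smaller batch
//     } else if (ret != 0) {
//         // error (see the log above)
//     }
//
//     // logits of the last token in the batch:
//     float * logits = llama_get_logits_ith(ctx, batch.n_tokens - 1);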
2783 | | |
2784 | | // |
2785 | | // perf |
2786 | | // |
2787 | | |
2788 | 0 | llama_perf_context_data llama_perf_context(const llama_context * ctx) { |
2789 | 0 | llama_perf_context_data data = {}; |
2790 | |
|
2791 | 0 | if (ctx == nullptr) { |
2792 | 0 | return data; |
2793 | 0 | } |
2794 | | |
2795 | 0 | data = ctx->perf_get_data(); |
2796 | |
|
2797 | 0 | return data; |
2798 | 0 | } |
2799 | | |
2800 | 0 | void llama_perf_context_print(const llama_context * ctx) { |
2801 | 0 | const auto data = llama_perf_context(ctx); |
2802 | |
|
2803 | 0 | const double t_end_ms = 1e-3 * ggml_time_us(); |
2804 | |
|
2805 | 0 | LLAMA_LOG_INFO("%s: load time = %10.2f ms\n", __func__, data.t_load_ms); |
2806 | 0 | LLAMA_LOG_INFO("%s: prompt eval time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)\n", |
2807 | 0 | __func__, data.t_p_eval_ms, data.n_p_eval, data.t_p_eval_ms / data.n_p_eval, 1e3 / data.t_p_eval_ms * data.n_p_eval); |
2808 | 0 | LLAMA_LOG_INFO("%s: eval time = %10.2f ms / %5d runs (%8.2f ms per token, %8.2f tokens per second)\n", |
2809 | 0 | __func__, data.t_eval_ms, data.n_eval, data.t_eval_ms / data.n_eval, 1e3 / data.t_eval_ms * data.n_eval); |
2810 | 0 | LLAMA_LOG_INFO("%s: total time = %10.2f ms / %5d tokens\n", __func__, (t_end_ms - data.t_start_ms), (data.n_p_eval + data.n_eval)); |
2811 | 0 | LLAMA_LOG_INFO("%s: graphs reused = %10d\n", __func__, data.n_reused); |
2812 | 0 | } |
2813 | | |
2814 | 0 | void llama_perf_context_reset(llama_context * ctx) { |
2815 | 0 | ctx->perf_reset(); |
2816 | 0 | } |
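// Typical usage after a run, assuming the context was created with perf measurements
// enabled (cparams.no_perf == false):
//
//     llama_perf_context_print(ctx);   // print load/prompt/eval timings
//     llama_perf_context_reset(ctx);   // start a fresh measurement window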
2817 | | |
2818 | 0 | void llama_memory_breakdown_print(const struct llama_context * ctx) { |
2819 | 0 | const std::vector<ggml_backend_dev_t> & devices = ctx->get_model().devices; |
2820 | |
|
2821 | 0 | std::map<ggml_backend_buffer_type_t, llama_memory_breakdown_data> memory_breakdown = ctx->memory_breakdown(); |
2822 | |
|
2823 | 0 | std::vector<std::array<std::string, 9>> table_data; |
2824 | 0 | table_data.reserve(devices.size()); |
2825 | 0 | const std::string template_header = "%s: | %s | %s %s %s %s %s %s %s |\n"; |
2826 | 0 | const std::string template_gpu = "%s: | %s | %s = %s + (%s = %s + %s + %s) + %s |\n"; |
2827 | 0 | const std::string template_other = "%s: | %s | %s %s %s = %s + %s + %s %s |\n"; |
2828 | |
|
2829 | 0 | table_data.push_back({template_header, "memory breakdown [MiB]", "total", "free", "self", "model", "context", "compute", "unaccounted"}); |
2830 | |
2831 | 0 | constexpr size_t MiB = 1024 * 1024; |
2832 | 0 | const std::vector<std::string> desc_prefixes_strip = {"NVIDIA ", "GeForce ", "Tesla ", "AMD ", "Radeon ", "Instinct "}; |
2833 | | |
2834 | | // track seen buffer types to avoid double counting: |
2835 | 0 | std::set<ggml_backend_buffer_type_t> seen_buffer_types; |
2836 | | |
2837 | | // accumulative memory breakdown for each device and for host: |
2838 | 0 | std::vector<llama_memory_breakdown_data> mb_dev(devices.size()); |
2839 | 0 | llama_memory_breakdown_data mb_host; |
2840 | |
2841 | 0 | for (const auto & buft_mb : memory_breakdown) { |
2842 | 0 | ggml_backend_buffer_type_t buft = buft_mb.first; |
2843 | 0 | const llama_memory_breakdown_data & mb = buft_mb.second; |
2844 | 0 | if (ggml_backend_buft_is_host(buft)) { |
2845 | 0 | mb_host.model += mb.model; |
2846 | 0 | mb_host.context += mb.context; |
2847 | 0 | mb_host.compute += mb.compute; |
2848 | 0 | seen_buffer_types.insert(buft); |
2849 | 0 | continue; |
2850 | 0 | } |
2851 | 0 | ggml_backend_dev_t dev = ggml_backend_buft_get_device(buft); |
2852 | 0 | if (dev) { |
2853 | 0 | int i_dev = -1; |
2854 | 0 | for (size_t i = 0; i < devices.size(); i++) { |
2855 | 0 | if (devices[i] == dev) { |
2856 | 0 | i_dev = i; |
2857 | 0 | break; |
2858 | 0 | } |
2859 | 0 | } |
2860 | 0 | if (i_dev != -1) { |
2861 | 0 | mb_dev[i_dev].model += mb.model; |
2862 | 0 | mb_dev[i_dev].context += mb.context; |
2863 | 0 | mb_dev[i_dev].compute += mb.compute; |
2864 | 0 | seen_buffer_types.insert(buft); |
2865 | 0 | continue; |
2866 | 0 | } |
2867 | 0 | } |
2868 | 0 | } |
2869 | | |
2870 | | // print memory breakdown for each device: |
2871 | 0 | for (size_t i = 0; i < devices.size(); i++) { |
2872 | 0 | ggml_backend_dev_t dev = devices[i]; |
2873 | 0 | llama_memory_breakdown_data mb = mb_dev[i]; |
2874 | |
2875 | 0 | const std::string name = ggml_backend_dev_name(dev); |
2876 | 0 | std::string desc = ggml_backend_dev_description(dev); |
2877 | 0 | for (const std::string & prefix : desc_prefixes_strip) { |
2878 | 0 | if (desc.length() >= prefix.length() && desc.substr(0, prefix.length()) == prefix) { |
2879 | 0 | desc = desc.substr(prefix.length()); |
2880 | 0 | } |
2881 | 0 | } |
2882 | |
2883 | 0 | size_t free, total; |
2884 | 0 | ggml_backend_dev_memory(dev, &free, &total); |
2885 | |
2886 | 0 | const size_t self = mb.model + mb.context + mb.compute; |
2887 | 0 | const size_t unaccounted = total - self - free; |
2888 | |
2889 | 0 | table_data.push_back({ |
2890 | 0 | template_gpu, |
2891 | 0 | " - " + name + " (" + desc + ")", |
2892 | 0 | std::to_string(total / MiB), |
2893 | 0 | std::to_string(free / MiB), |
2894 | 0 | std::to_string(self / MiB), |
2895 | 0 | std::to_string(mb.model / MiB), |
2896 | 0 | std::to_string(mb.context / MiB), |
2897 | 0 | std::to_string(mb.compute / MiB), |
2898 | 0 | std::to_string(unaccounted / MiB)}); |
2899 | 0 | } |
2900 | | |
2901 | | // print memory breakdown for host: |
2902 | 0 | { |
2903 | 0 | const size_t self = mb_host.model + mb_host.context + mb_host.compute; |
2904 | 0 | table_data.push_back({ |
2905 | 0 | template_other, |
2906 | 0 | " - Host", |
2907 | 0 | "", // total |
2908 | 0 | "", // free |
2909 | 0 | std::to_string(self / MiB), |
2910 | 0 | std::to_string(mb_host.model / MiB), |
2911 | 0 | std::to_string(mb_host.context / MiB), |
2912 | 0 | std::to_string(mb_host.compute / MiB), |
2913 | 0 | ""}); // unaccounted |
2914 | 0 | } |
2915 | | |
2916 | | // print memory breakdown for all remaining buffer types: |
2917 | 0 | for (const auto & buft_mb : memory_breakdown) { |
2918 | 0 | ggml_backend_buffer_type_t buft = buft_mb.first; |
2919 | 0 | const llama_memory_breakdown_data & mb = buft_mb.second; |
2920 | 0 | if (seen_buffer_types.count(buft) == 1) { |
2921 | 0 | continue; |
2922 | 0 | } |
2923 | 0 | const std::string name = ggml_backend_buft_name(buft); |
2924 | 0 | const size_t self = mb.model + mb.context + mb.compute; |
2925 | 0 | table_data.push_back({ |
2926 | 0 | template_other, |
2927 | 0 | " - " + name, |
2928 | 0 | "", // total |
2929 | 0 | "", // free |
2930 | 0 | std::to_string(self / MiB), |
2931 | 0 | std::to_string(mb.model / MiB), |
2932 | 0 | std::to_string(mb.context / MiB), |
2933 | 0 | std::to_string(mb.compute / MiB), |
2934 | 0 | ""}); // unaccounted |
2935 | 0 | seen_buffer_types.insert(buft); |
2936 | 0 | } |
2937 | |
2938 | 0 | for (size_t j = 1; j < table_data[0].size(); j++) { |
2939 | 0 | size_t max_len = 0; |
2940 | 0 | for (const auto & td : table_data) { |
2941 | 0 | max_len = std::max(max_len, td[j].length()); |
2942 | 0 | } |
2943 | 0 | for (auto & td : table_data) { |
2944 | 0 | td[j].insert(j == 1 ? td[j].length() : 0, max_len - td[j].length(), ' '); |
2945 | 0 | } |
2946 | 0 | } |
2947 | 0 | for (const auto & td : table_data) { |
2948 | 0 | LLAMA_LOG_INFO(td[0].c_str(), |
2949 | 0 | __func__, td[1].c_str(), td[2].c_str(), td[3].c_str(), td[4].c_str(), td[5].c_str(), |
2950 | 0 | td[6].c_str(), td[7].c_str(), td[8].c_str()); |
2951 | 0 | } |
2952 | 0 | } |
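For each device row the derived columns above satisfy self = model + context + compute and unaccounted = total - self - free, and template_gpu renders the row as "total = free + (self = model + context + compute) + unaccounted". A worked example with hypothetical numbers (MiB, not real measurements; the device name is illustrative):

// hypothetical device: total = 24576, free = 18000, model = 4000, context = 1500, compute = 800
//   self        = 4000 + 1500 + 800    = 6300
//   unaccounted = 24576 - 6300 - 18000 = 276
// rendered row:  | - CUDA0 (<device>) | 24576 = 18000 + (6300 = 4000 + 1500 + 800) + 276 |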
2953 | | |
2954 | | // |
2955 | | // training |
2956 | | // |
2957 | | |
2958 | 0 | bool llama_opt_param_filter_all(const struct ggml_tensor * tensor, void * userdata) { |
2959 | 0 | GGML_UNUSED(tensor); |
2960 | 0 | GGML_UNUSED(userdata); |
2961 | 0 | return true; |
2962 | 0 | } |
2963 | | |
2964 | 0 | void llama_opt_init(struct llama_context * ctx, struct llama_model * model, struct llama_opt_params lopt_params) { |
2965 | 0 | ctx->opt_init(model, lopt_params); |
2966 | 0 | } |
2967 | | |
2968 | | void llama_opt_epoch( |
2969 | | struct llama_context * ctx, |
2970 | | ggml_opt_dataset_t dataset, |
2971 | | ggml_opt_result_t result_train, |
2972 | | ggml_opt_result_t result_eval, |
2973 | | int64_t idata_split, |
2974 | | ggml_opt_epoch_callback callback_train, |
2975 | 0 | ggml_opt_epoch_callback callback_eval) { |
2976 | 0 | ctx->opt_epoch( |
2977 | 0 | dataset, |
2978 | 0 | result_train, |
2979 | 0 | result_eval, |
2980 | 0 | idata_split, |
2981 | 0 | callback_train, |
2982 | 0 | callback_eval); |
2983 | 0 | } |
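The training wrappers above forward directly to the context; dataset construction and optimizer configuration live in the caller. A hedged sketch of one epoch (the llama_opt_params field names and the ggml_opt_* helpers are taken from the public llama.h / ggml-opt.h headers and should be treated as assumptions here, not as guarantees of this file; the helper name is hypothetical):

#include "llama.h"
#include "ggml-opt.h"

// hypothetical helper: fine-tune for one epoch on a prepared dataset
static void example_train_epoch(llama_context * ctx, llama_model * model, ggml_opt_dataset_t dataset) {
    llama_opt_params lopt_params {};
    lopt_params.n_ctx_train       = 0;                          // 0: use the context size of ctx
    lopt_params.param_filter      = llama_opt_param_filter_all; // mark every tensor as trainable
    lopt_params.param_filter_data = nullptr;
    lopt_params.get_opt_pars      = ggml_opt_get_default_optimizer_params;
    lopt_params.get_opt_pars_ud   = nullptr;

    llama_opt_init(ctx, model, lopt_params);

    ggml_opt_result_t result_train = ggml_opt_result_init();
    ggml_opt_result_t result_eval  = ggml_opt_result_init();

    // use the first 90% of the data for training and the rest for evaluation
    const int64_t idata_split = ggml_opt_dataset_ndata(dataset) * 9 / 10;

    llama_opt_epoch(ctx, dataset, result_train, result_eval, idata_split,
                    ggml_opt_epoch_callback_progress_bar,
                    ggml_opt_epoch_callback_progress_bar);

    ggml_opt_result_free(result_train);
    ggml_opt_result_free(result_eval);
}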