/src/llama.cpp/src/llama-context.h
#pragma once

#include "llama.h"
#include "llama-cparams.h"
#include "llama-graph.h"
#include "llama-adapter.h"

#include "ggml-cpp.h"
#include "ggml-opt.h"

#include <map>
#include <vector>

struct llama_model;
class llama_batch_allocr;

class llama_io_read_i;
class llama_io_write_i;

// "memory" as in abstract memory for the context
struct llama_memory_i;
struct llama_memory_context_i;

// "memory" as in physical memory for a buffer type, in bytes
struct llama_memory_breakdown_data {
    size_t model   = 0; // memory allocated for the model
    size_t context = 0; // memory allocated for the context
    size_t compute = 0; // memory allocated for temporary compute buffers

    size_t total() const {
        return model + context + compute;
    }
};

struct llama_context {
    // init scheduler and compute buffers, reserve worst-case graphs
    llama_context(
            const llama_model & model,
            llama_context_params params);

    ~llama_context();

    // reserve a new backend scheduler (if needed)
    // for example, when:
    //  - changing loras
    //  - changing samplers
    //  - changing attention type
    //  - etc.
    void sched_reserve();
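
    // a hedged sketch of when a caller might trigger the reserve above; whether the
    // context triggers it internally or expects the caller to do so is not visible from
    // this header, so the sequence is illustrative only:
    //
    //     ctx.set_adapter_lora(adapter, 1.0f); // declared further below
    //     ctx.sched_reserve();                 // worst-case graphs may have changed shape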

    void synchronize();

    const llama_model & get_model() const;
    const llama_cparams & get_cparams() const;

    ggml_backend_sched_t get_sched() const;

    uint32_t n_ctx() const;
    uint32_t n_ctx_seq() const;
    uint32_t n_batch() const;
    uint32_t n_ubatch() const;
    uint32_t n_seq_max() const;

    uint32_t n_threads() const;
    uint32_t n_threads_batch() const;

    llama_memory_t get_memory() const;

    // return true if the memory was updated
    bool memory_update(bool optimize);

    enum llama_pooling_type pooling_type() const;

    float * get_logits();
    float * get_logits_ith(int32_t i);

    float * get_embeddings();
    float * get_embeddings_ith(int32_t i);
    float * get_embeddings_seq(llama_seq_id seq_id);

    llama_token * get_sampled_tokens() const;
    llama_token get_sampled_token_ith(int32_t idx);

    float * get_sampled_logits_ith(int32_t idx);
    size_t get_sampled_logits_count(int32_t idx);

    float * get_sampled_probs_ith(int32_t idx);
    size_t get_sampled_probs_count(int32_t idx);

    const llama_token * get_sampled_candidates_ith(int32_t idx);
    size_t get_sampled_candidates_count(int32_t idx);
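
    // a minimal sketch of the sampling getters above; it assumes a sampler was attached
    // with set_sampler() (declared further below), that decode() follows the usual
    // 0-on-success convention, and that `idx` is an output index as used by get_logits_ith():
    //
    //     if (ctx.decode(batch) == 0) {
    //         const llama_token tok = ctx.get_sampled_token_ith(idx);  // token chosen by the sampler
    //         const size_t      n_p = ctx.get_sampled_probs_count(idx);
    //         const float *     p   = ctx.get_sampled_probs_ith(idx);  // n_p probabilities
    //     }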

    void attach_threadpool(
            ggml_threadpool_t threadpool,
            ggml_threadpool_t threadpool_batch);

    void detach_threadpool();

    void set_n_threads(int32_t n_threads, int32_t n_threads_batch);

    void set_abort_callback(bool (*abort_callback)(void * data), void * abort_callback_data);

    void set_embeddings (bool value);
    void set_causal_attn(bool value);
    void set_warmup(bool value);

    void set_adapter_lora(
            llama_adapter_lora * adapter,
            float scale);

    bool rm_adapter_lora(
            llama_adapter_lora * adapter);

    void clear_adapter_lora();

    bool apply_adapter_cvec(
            const float * data,
            size_t len,
            int32_t n_embd,
            int32_t il_start,
            int32_t il_end);
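
    // illustrative sketch of a control-vector call; the layout of `data` (one n_embd-sized
    // direction per layer in [il_start, il_end]) is an assumption of this sketch, not
    // something this header specifies:
    //
    //     std::vector<float> cvec(n_layers * n_embd); // hypothetical buffer, filled elsewhere
    //     ctx.apply_adapter_cvec(cvec.data(), cvec.size(), n_embd, /*il_start=*/1, /*il_end=*/n_layers);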

    // process a single ubatch with a specific graph type
    // if memory_context is provided, it will be applied first to the context's memory
    // ret contains the status of the graph computation
    // returns nullptr only if ret != GGML_STATUS_SUCCESS
    llm_graph_result * process_ubatch(
            const llama_ubatch & ubatch,
            llm_graph_type gtype,
            llama_memory_context_i * mctx,
            ggml_status & ret);

    int encode(const llama_batch & batch_inp);
    int decode(const llama_batch & batch_inp);
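
    // a minimal usage sketch for the two entry points above, assuming the usual llama.cpp
    // convention that 0 means success; how the batch is built is out of scope here:
    //
    //     if (ctx.decode(batch) == 0) {
    //         // i is an output index into the current batch (see output_ids below)
    //         float * logits = ctx.get_logits_ith(i); // one [n_vocab] row
    //     }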

    //
    // state save/load
    //

    size_t state_get_size();
    size_t state_get_data(      uint8_t * dst, size_t size);
    size_t state_set_data(const uint8_t * src, size_t size);

    size_t state_seq_get_size(llama_seq_id seq_id, llama_state_seq_flags flags);
    size_t state_seq_get_data(llama_seq_id seq_id,       uint8_t * dst, size_t size, llama_state_seq_flags flags);
    size_t state_seq_set_data(llama_seq_id seq_id, const uint8_t * src, size_t size, llama_state_seq_flags flags);

    bool state_load_file(
            const char * filepath,
            llama_token * tokens_out,
            size_t n_token_capacity,
            size_t * n_token_count_out);

    bool state_save_file(
            const char * filepath,
            const llama_token * tokens,
            size_t n_token_count);

    size_t state_seq_load_file(
            llama_seq_id seq_id,
            const char * filepath,
            llama_token * tokens_out,
            size_t n_token_capacity,
            size_t * n_token_count_out);

    size_t state_seq_save_file(
            llama_seq_id seq_id,
            const char * filepath,
            const llama_token * tokens,
            size_t n_token_count);
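
    // sketch of a save/restore round trip using the declarations above; the tokens are
    // stored alongside the state, presumably so the caller can re-validate the prompt on
    // load (sizes and error handling simplified):
    //
    //     ctx.state_save_file("session.bin", tokens.data(), tokens.size());
    //
    //     std::vector<llama_token> restored(n_token_capacity);
    //     size_t n_restored = 0;
    //     if (ctx.state_load_file("session.bin", restored.data(), restored.size(), &n_restored)) {
    //         restored.resize(n_restored);
    //     }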

    //
    // perf
    //

    llama_perf_context_data perf_get_data() const;
    void perf_reset();

    std::map<ggml_backend_buffer_type_t, llama_memory_breakdown_data> memory_breakdown() const;
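
    // sketch: aggregating the per-buffer-type breakdown returned above using
    // llama_memory_breakdown_data::total() from the top of this header; printing the
    // buffer type name via ggml_backend_buft_name() is an assumption of this sketch:
    //
    //     for (const auto & [buft, mb] : ctx.memory_breakdown()) {
    //         printf("%s: %zu bytes (model %zu, context %zu, compute %zu)\n",
    //                 ggml_backend_buft_name(buft), mb.total(), mb.model, mb.context, mb.compute);
    //     }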

    //
    // training
    //

    void opt_init(struct llama_model * model, struct llama_opt_params lopt_params);

    // TODO: more flexible combinations of logical/physical batch size and context size
    void opt_epoch(
            ggml_opt_dataset_t dataset,
            ggml_opt_result_t result_train,
            ggml_opt_result_t result_eval,
            int64_t idata_split,
            ggml_opt_epoch_callback callback_train,
            ggml_opt_epoch_callback callback_eval);

    void opt_epoch_iter(
            ggml_opt_dataset_t dataset,
            ggml_opt_result_t result,
            const std::vector<llama_token> & tokens,
            const std::vector<llama_token> & labels_sparse,
            llama_batch & batch,
            ggml_opt_epoch_callback callback,
            bool train,
            int64_t idata_in_loop,
            int64_t ndata_in_loop,
            int64_t t_loop_start);
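
    // a hedged sketch of the intended call order, based only on the signatures above:
    // opt_init() once, then one or more epochs over a dataset that is split into a
    // training part and an evaluation part at `idata_split`; the dataset, result objects
    // and callbacks are created elsewhere and only assumed here:
    //
    //     ctx.opt_init(&model, lopt_params);
    //     ctx.opt_epoch(dataset, result_train, result_eval, idata_split,
    //                   callback_train, callback_eval);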

private:
    //
    // output
    //

    // Make sure enough space is available for outputs.
    // Returns max number of outputs for which space was reserved.
    uint32_t output_reserve(int32_t n_outputs, const llama_batch & batch);

    void output_reorder();

    // map the output row index `i` to batch index
    int64_t output_resolve_row(int32_t i) const;

    //
    // graph
    //

public:
    uint32_t graph_max_nodes(uint32_t n_tokens) const;

    // can reuse the llm_graph_result instance of the context (for example to update a memory module)
    llm_graph_result * get_gf_res_reserve() const;

    // returns the result of ggml_backend_sched_graph_compute_async execution
    ggml_status graph_compute(ggml_cgraph * gf, bool batched);

    // reserve a graph with a dummy ubatch of the specified size
    ggml_cgraph * graph_reserve(
            uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx, bool split_only = false, size_t * sizes = nullptr);

    bool set_sampler(llama_seq_id seq_id, llama_sampler * sampler);

private:
    llm_graph_params graph_params(
            llm_graph_result * res,
            const llama_ubatch & ubatch,
            const llama_memory_context_i * mctx,
            llm_graph_type gtype) const;

    llm_graph_cb graph_get_cb() const;

    // TODO: read/write lora adapters and cvec
    size_t state_write_data(llama_io_write_i & io);
    size_t state_read_data (llama_io_read_i & io);

    size_t state_seq_write_data(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags);
    size_t state_seq_read_data (llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags);

    //
    // members
    //

    const llama_model & model;

    llama_cparams cparams;
    llama_adapter_cvec cvec;
    llama_adapter_loras loras;

    llama_cross cross; // TODO: tmp for handling cross-attention - need something better probably

    std::unique_ptr<llama_memory_i> memory;

    // decode output (2-dimensional array: [n_outputs][n_vocab])
    size_t logits_size = 0; // capacity (of floats) for logits
    float * logits = nullptr;

    // embeddings output (2-dimensional array: [n_outputs][n_embd])
    // populated only when pooling_type == LLAMA_POOLING_TYPE_NONE
    size_t embd_size = 0; // capacity (of floats) for embeddings
    float * embd = nullptr;

    // TODO: simplify
    struct sampling_info {
        std::map<llama_seq_id, llama_sampler *> samplers;

        float * logits = nullptr;
        size_t logits_size = 0;

        llama_token * sampled = nullptr;
        size_t sampled_size = 0;

        float * probs = nullptr;
        size_t probs_size = 0;

        llama_token * candidates = nullptr;
        size_t candidates_size = 0;

        std::vector<uint32_t> logits_count;
        std::vector<uint32_t> probs_count;
        std::vector<uint32_t> candidates_count;

        std::vector<llama_token> token_ids_full_vocab;
    };

    sampling_info sampling;

    // sequence embeddings output (map of [n_embd] vectors)
    // populated only when pooling_type != LLAMA_POOLING_TYPE_NONE
    std::map<llama_seq_id, std::vector<float>> embd_seq;

    // reuse the batch_allocr to avoid unnecessary memory allocations
    std::unique_ptr<llama_batch_allocr> balloc;

    uint32_t n_outputs = 0; // number of actually-used outputs in the current ubatch or last logical batch

    std::vector<int32_t> output_ids; // map batch token positions to ids of the logits and embd buffers

    struct swap_info {
        uint32_t i0;
        uint32_t i1;
    };

    std::vector<swap_info> output_swaps;

    ggml_backend_sched_ptr sched;

    bool sched_need_reserve = true;

    ggml_backend_t backend_cpu = nullptr;
    std::vector<ggml_backend_ptr> backends;

    // training
    ggml_opt_context_t opt_ctx = nullptr;

    ggml_threadpool_t threadpool = nullptr;
    ggml_threadpool_t threadpool_batch = nullptr;

    ggml_abort_callback abort_callback = nullptr;
    void * abort_callback_data = nullptr;

    std::vector<std::pair<ggml_backend_t, ggml_backend_set_n_threads_t>> set_n_threads_fns;

    // pointers and buffer types used for the compute buffer of each backend
    std::vector<ggml_backend_t> backend_ptrs;
    std::vector<ggml_backend_buffer_type_t> backend_buft;
    std::vector<size_t> backend_buf_exp_size; // expected buffer sizes

    llm_graph_result_ptr gf_res_prev;
    llm_graph_result_ptr gf_res_reserve;

    // host buffer for the model output (logits and embeddings)
    ggml_backend_buffer_ptr buf_output;

    bool has_evaluated_once = false;

    // env: LLAMA_GRAPH_REUSE_DISABLE
    bool graph_reuse_disable = false;

    // perf
    mutable int64_t t_start_us = 0;
    mutable int64_t t_load_us = 0;
    mutable int64_t t_p_eval_us = 0;
    mutable int64_t t_eval_us = 0;

    mutable int64_t t_compute_start_us = 0;
    mutable int64_t n_queued_tokens = 0;

    mutable int32_t n_p_eval = 0; // number of tokens in eval calls for the prompt (with batch size > 1)
    mutable int32_t n_eval = 0; // number of eval calls

    mutable int32_t n_reused = 0; // number of times the previous graph was reused
};