/src/llama.cpp/src/llama-context.h
#pragma once

#include "llama.h"
#include "llama-cparams.h"
#include "llama-graph.h"
#include "llama-adapter.h"
#include "llama-impl.h"

#include "ggml-cpp.h"
#include "ggml-opt.h"

#include <map>
#include <vector>

struct llama_model;
class llama_batch_allocr;

class llama_io_read_i;
class llama_io_write_i;

// "memory" as in abstract memory for the context
struct llama_memory_i;
struct llama_memory_context_i;

// "memory" as in physical memory for a buffer type, in bytes
struct llama_memory_breakdown_data {
    size_t model   = 0; // memory allocated for the model
    size_t context = 0; // memory allocated for the context
    size_t compute = 0; // memory allocated for temporary compute buffers

    size_t total() const {
        return model + context + compute;
    }
};

struct llama_context {
    // init scheduler and compute buffers, reserve worst-case graphs
    llama_context(
            const llama_model & model,
            llama_context_params params);

    ~llama_context();
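
    // Illustrative usage sketch (not part of this header): user code typically reaches this
    // constructor through the public C API in llama.h. The model path and parameter values
    // below are placeholders.
    //
    //     llama_model_params   mparams = llama_model_default_params();
    //     llama_context_params cparams = llama_context_default_params();
    //     cparams.n_ctx   = 4096;
    //     cparams.n_batch = 512;
    //
    //     llama_model   * model = llama_model_load_from_file("model.gguf", mparams);
    //     llama_context * ctx   = llama_init_from_model(model, cparams);
    //     // ... use ctx ...
    //     llama_free(ctx);
    //     llama_model_free(model);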

    // reserve a new backend scheduler (if needed)
    // for example, when:
    // - changing loras
    // - changing samplers
    // - changing attention type
    // - etc.
    void sched_reserve();
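
    // Assumed internal flow (sketch, inferred from the sched_need_reserve member below):
    // setters that invalidate the current scheduler state (e.g. set_adapters_lora(),
    // set_sampler()) mark the context as needing a new reservation, and the next
    // encode()/decode() calls sched_reserve() before building its graph.
    //
    //     ctx->set_adapters_lora(adapters, n_adapters, scales); // invalidates the scheduler
    //     ctx->decode(batch);                                   // re-reserves, then computes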

    void synchronize();

    const llama_model & get_model() const;
    const llama_cparams & get_cparams() const;

    ggml_backend_sched_t get_sched() const;

    uint32_t n_ctx() const;
    uint32_t n_ctx_seq() const;
    uint32_t n_batch() const;
    uint32_t n_ubatch() const;
    uint32_t n_seq_max() const;

    uint32_t n_threads() const;
    uint32_t n_threads_batch() const;

    llama_memory_t get_memory() const;

    // return true if the memory was updated
    bool memory_update(bool optimize);

    enum llama_pooling_type pooling_type() const;

    float * get_logits();
    float * get_logits_ith(int32_t i);

    float * get_embeddings();
    float * get_embeddings_ith(int32_t i);
    float * get_embeddings_seq(llama_seq_id seq_id);

    llama_token * get_sampled_tokens() const;
    llama_token get_sampled_token_ith(int32_t idx);

    float * get_sampled_logits_ith(int32_t idx);
    size_t get_sampled_logits_count(int32_t idx);

    float * get_sampled_probs_ith(int32_t idx);
    size_t get_sampled_probs_count(int32_t idx);

    const llama_token * get_sampled_candidates_ith(int32_t idx);
    size_t get_sampled_candidates_count(int32_t idx);
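
    // Illustrative sketch (not part of this header): the logits/embeddings getters back the
    // public accessors in llama.h; indices refer to rows of the last batch that requested
    // output. The get_sampled_* accessors expose results of samplers registered with
    // set_sampler() (declared further below).
    //
    //     if (llama_decode(ctx, batch) == 0) {
    //         const float * logits = llama_get_logits_ith(ctx, batch.n_tokens - 1);
    //         const float * embd   = llama_get_embeddings_seq(ctx, 0); // pooled, if enabled
    //     }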

    void attach_threadpool(
            ggml_threadpool_t threadpool,
            ggml_threadpool_t threadpool_batch);

    void detach_threadpool();

    void set_n_threads(int32_t n_threads, int32_t n_threads_batch);

    void set_abort_callback(bool (*abort_callback)(void * data), void * abort_callback_data);

    void set_embeddings (bool value);
    void set_causal_attn(bool value);
    void set_warmup(bool value);
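
    // Illustrative sketch (not part of this header): these runtime switches map onto the
    // corresponding llama.h calls; the abort callback is polled during graph computation
    // and returning true cancels the current call. should_abort is a placeholder flag.
    //
    //     llama_set_n_threads(ctx, /*n_threads =*/ 8, /*n_threads_batch =*/ 16);
    //     llama_set_embeddings(ctx, true); // also produce embeddings while decoding
    //     llama_set_abort_callback(ctx,
    //             [](void * data) { return *static_cast<bool *>(data); }, &should_abort);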

    void set_adapters_lora(llama_adapter_lora ** adapters, size_t n_adapters, float * scales);

    bool adapters_lora_are_same(llama_adapter_lora ** adapters, size_t n_adapters, float * scales);

    bool set_adapter_cvec(
            const float * data,
            size_t len,
            int32_t n_embd,
            int32_t il_start,
            int32_t il_end);
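
    // Illustrative sketch (not part of this header): applying a set of LoRA adapters and a
    // control vector. The adapter handles, scales and cvec buffer are placeholders loaded
    // elsewhere.
    //
    //     llama_adapter_lora * adapters[] = { lora_a, lora_b };
    //     float                scales[]   = { 1.0f, 0.5f };
    //     ctx->set_adapters_lora(adapters, 2, scales);
    //
    //     // apply a control vector to a range of layers
    //     ctx->set_adapter_cvec(cvec_data, cvec_len, n_embd, /*il_start =*/ 1, /*il_end =*/ 31);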

    // process a single ubatch with a specific graph type
    // if memory_context is provided, it will be applied first to the context's memory
    // ret contains the status of the graph computation
    // returns nullptr only if ret != GGML_STATUS_SUCCESS
    llm_graph_result * process_ubatch(
            const llama_ubatch & ubatch,
            llm_graph_type gtype,
            llama_memory_context_i * mctx,
            ggml_status & ret);

    int encode(const llama_batch & batch_inp);
    int decode(const llama_batch & batch_inp);
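
    // Illustrative sketch (not part of this header): decode() is what llama_decode() forwards
    // to. By llama.cpp convention it returns 0 on success, a positive value for recoverable
    // conditions (e.g. no free slot in the memory/KV cache) and a negative value on error.
    //
    //     std::vector<llama_token> tokens = /* tokenized prompt */;
    //     llama_batch batch = llama_batch_get_one(tokens.data(), (int32_t) tokens.size());
    //     if (llama_decode(ctx, batch) != 0) {
    //         // handle failure (shrink the batch, free sequences, etc.)
    //     }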

    //
    // state save/load
    //

    size_t state_get_size();
    size_t state_get_data(      uint8_t * dst, size_t size);
    size_t state_set_data(const uint8_t * src, size_t size);

    size_t state_seq_get_size(llama_seq_id seq_id, llama_state_seq_flags flags);
    size_t state_seq_get_data(llama_seq_id seq_id,       uint8_t * dst, size_t size, llama_state_seq_flags flags);
    size_t state_seq_set_data(llama_seq_id seq_id, const uint8_t * src, size_t size, llama_state_seq_flags flags);

    bool state_load_file(
            const char * filepath,
            llama_token * tokens_out,
            size_t n_token_capacity,
            size_t * n_token_count_out);

    bool state_save_file(
            const char * filepath,
            const llama_token * tokens,
            size_t n_token_count);

    size_t state_seq_load_file(
            llama_seq_id seq_id,
            const char * filepath,
            llama_token * tokens_out,
            size_t n_token_capacity,
            size_t * n_token_count_out);

    size_t state_seq_save_file(
            llama_seq_id seq_id,
            const char * filepath,
            const llama_token * tokens,
            size_t n_token_count);
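
    // Illustrative sketch (not part of this header): a full-context snapshot through the
    // llama.h wrappers, which route into state_get_size()/state_get_data()/state_set_data().
    //
    //     std::vector<uint8_t> state(llama_state_get_size(ctx));
    //     llama_state_get_data(ctx, state.data(), state.size());
    //     // ... later, restore into a context created with the same model and params ...
    //     llama_state_set_data(ctx, state.data(), state.size());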

    //
    // perf
    //

    llama_perf_context_data perf_get_data() const;
    void perf_reset();

    std::map<ggml_backend_buffer_type_t, llama_memory_breakdown_data> memory_breakdown() const;
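
    // Illustrative sketch (not part of this header): perf counters are exposed publicly via
    // llama_perf_context(); memory_breakdown() reports per-buffer-type usage whose fields
    // can be summed with llama_memory_breakdown_data::total().
    //
    //     llama_perf_context_data pd = llama_perf_context(ctx);
    //     LLAMA_LOG_INFO("prompt eval: %.2f ms over %d tokens\n", pd.t_p_eval_ms, pd.n_p_eval);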

    //
    // training
    //

    void opt_init(struct llama_model * model, struct llama_opt_params lopt_params);

    // TODO: more flexible combinations of logical/physical batch size and context size
    void opt_epoch(
            ggml_opt_dataset_t dataset,
            ggml_opt_result_t result_train,
            ggml_opt_result_t result_eval,
            int64_t idata_split,
            ggml_opt_epoch_callback callback_train,
            ggml_opt_epoch_callback callback_eval);

    void opt_epoch_iter(
            ggml_opt_dataset_t dataset,
            ggml_opt_result_t result,
            const std::vector<llama_token> & tokens,
            const std::vector<llama_token> & labels_sparse,
            llama_batch & batch,
            ggml_opt_epoch_callback callback,
            bool train,
            int64_t idata_in_loop,
            int64_t ndata_in_loop,
            int64_t t_loop_start);
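
    // Illustrative sketch (not part of this header): the training entry points are exposed as
    // llama_opt_init()/llama_opt_epoch() in llama.h; the dataset, result objects and split
    // index below are placeholders created with the ggml-opt API.
    //
    //     llama_opt_init(ctx, model, lopt_params);
    //     llama_opt_epoch(ctx, dataset, result_train, result_eval,
    //             /*idata_split =*/ idata_split,
    //             ggml_opt_epoch_callback_progress_bar,
    //             ggml_opt_epoch_callback_progress_bar);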

private:
    //
    // output
    //

    // Make sure enough space is available for outputs.
    // Returns max number of outputs for which space was reserved.
    uint32_t output_reserve(int32_t n_outputs);

    void output_reorder();

    // map the output row index `i` to batch index
    int64_t output_resolve_row(int32_t i) const;

    //
    // graph
    //

public:
    uint32_t graph_max_nodes(uint32_t n_tokens) const;

    // can reuse the llm_graph_result instance of the context (for example to update a memory module)
    llm_graph_result * get_gf_res_reserve() const;

    // returns the result of ggml_backend_sched_graph_compute_async execution
    ggml_status graph_compute(ggml_cgraph * gf, bool batched);

    // reserve a graph with a dummy ubatch of the specified size
    ggml_cgraph * graph_reserve(
            uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx, bool split_only = false, size_t * sizes = nullptr);

    bool set_sampler(llama_seq_id seq_id, llama_sampler * sampler);
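
    // Assumed internal flow (sketch): a worst-case graph is reserved so the backend scheduler
    // can size its compute buffers, and per-ubatch graphs built in process_ubatch() are then
    // executed with graph_compute(). The argument values below are placeholders.
    //
    //     ggml_cgraph * gf = graph_reserve(cparams.n_ubatch, cparams.n_seq_max, cparams.n_ubatch, mctx);
    //     ...
    //     ggml_status status = graph_compute(gf_cur, /*batched =*/ ubatch.n_tokens > 1);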

private:
    llm_graph_params graph_params(
            llm_graph_result * res,
            const llama_ubatch & ubatch,
            const llama_memory_context_i * mctx,
            llm_graph_type gtype) const;

    llm_graph_cb graph_get_cb() const;

    // TODO: read/write lora adapters and cvec
    size_t state_write_data(llama_io_write_i & io);
    size_t state_read_data (llama_io_read_i & io);

    size_t state_seq_write_data(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags);
    size_t state_seq_read_data (llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags);

    //
    // members
    //

    const llama_model & model;

    llama_cparams cparams;

    llama_adapter_cvec_ptr cvec;
    llama_adapter_loras_ptr loras;

    llama_cross cross; // TODO: tmp for handling cross-attention - need something better probably

    std::unique_ptr<llama_memory_i> memory;

    // decode output (2-dimensional array: [n_outputs][n_vocab])
    buffer_view<float> logits = {nullptr, 0};

    // embeddings output (2-dimensional array: [n_outputs][n_embd])
    // populated only when pooling_type == LLAMA_POOLING_TYPE_NONE
    buffer_view<float> embd = {nullptr, 0};

    struct sampling_info {
        // !samplers.empty() to check if any samplers are active
        std::map<llama_seq_id, llama_sampler *> samplers;

        buffer_view<float> logits = {nullptr, 0};
        buffer_view<llama_token> sampled = {nullptr, 0};
        buffer_view<float> probs = {nullptr, 0};
        buffer_view<llama_token> candidates = {nullptr, 0};

        std::vector<uint32_t> logits_count;
        std::vector<uint32_t> probs_count;
        std::vector<uint32_t> candidates_count;

        // optimization
        std::vector<llama_token> token_ids_full_vocab;
    };

    sampling_info sampling;

    // sequence embeddings output (map of [n_embd] vectors)
    // populated only when pooling_type != LLAMA_POOLING_TYPE_NONE
    std::map<llama_seq_id, std::vector<float>> embd_seq;

    // reuse the batch_allocr to avoid unnecessary memory allocations
    std::unique_ptr<llama_batch_allocr> balloc;

    uint32_t n_outputs = 0; // number of actually-used outputs in the current ubatch or last logical batch

    std::vector<int32_t> output_ids; // map batch token positions to ids of the logits and embd buffers

    struct swap_info {
        uint32_t i0;
        uint32_t i1;
    };

    std::vector<swap_info> output_swaps;

    ggml_backend_sched_ptr sched;

    bool sched_need_reserve = true;

    ggml_backend_t backend_cpu = nullptr;
    std::vector<ggml_backend_ptr> backends;

    // training
    ggml_opt_context_t opt_ctx = nullptr;

    ggml_threadpool_t threadpool = nullptr;
    ggml_threadpool_t threadpool_batch = nullptr;

    ggml_abort_callback abort_callback = nullptr;
    void * abort_callback_data = nullptr;

    std::vector<std::pair<ggml_backend_t, ggml_backend_set_n_threads_t>> set_n_threads_fns;

    // pointers and buffer types used for the compute buffer of each backend
    std::vector<ggml_backend_t> backend_ptrs;
    std::vector<ggml_backend_buffer_type_t> backend_buft;
    std::vector<size_t> backend_buf_exp_size; // expected buffer sizes

    llm_graph_result_ptr gf_res_prev;
    llm_graph_result_ptr gf_res_reserve;

    // host buffer for the model output (logits and embeddings)
    ggml_backend_buffer_ptr buf_output;

    bool has_evaluated_once = false;

    // env: LLAMA_GRAPH_REUSE_DISABLE
    bool graph_reuse_disable = false;

    // perf
    mutable int64_t t_start_us = 0;
    mutable int64_t t_load_us = 0;
    mutable int64_t t_p_eval_us = 0;
    mutable int64_t t_eval_us = 0;

    mutable int64_t t_compute_start_us = 0;
    mutable int64_t n_queued_tokens = 0;

    mutable int32_t n_p_eval = 0; // number of tokens in eval calls for the prompt (with batch size > 1)
    mutable int32_t n_eval = 0; // number of eval calls

    mutable int32_t n_reused = 0; // number of times the previous graph was reused
};