/src/llama.cpp/src/llama-context.h
#pragma once

#include "llama.h"
#include "llama-cparams.h"
#include "llama-graph.h"
#include "llama-adapter.h"

#include "ggml-cpp.h"
#include "ggml-opt.h"

#include <map>
#include <vector>

struct llama_model;
class llama_batch_allocr;

class llama_io_read_i;
class llama_io_write_i;

// "memory" as in abstract memory for the context
struct llama_memory_i;
struct llama_memory_context_i;

// "memory" as in physical memory for a buffer type, in bytes
struct llama_memory_breakdown_data {
    size_t model   = 0; // memory allocated for the model
    size_t context = 0; // memory allocated for the context
    size_t compute = 0; // memory allocated for temporary compute buffers

    size_t total() const {
        return model + context + compute;
    }
};
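
// Usage sketch (illustrative, not part of the upstream header): how the per-buffer-type
// breakdown returned by llama_context::memory_breakdown() (declared further below) could
// be aggregated into a single llama_memory_breakdown_data. `ctx` is assumed to be an
// already-constructed llama_context; the helper name is hypothetical.
#if 0
static llama_memory_breakdown_data memory_breakdown_sum(const llama_context & ctx) {
    llama_memory_breakdown_data sum;
    for (const auto & kv : ctx.memory_breakdown()) {
        sum.model   += kv.second.model;
        sum.context += kv.second.context;
        sum.compute += kv.second.compute;
    }
    return sum; // sum.total() gives the total bytes across all buffer types
}
#endif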

struct llama_context {
    // init scheduler and compute buffers, reserve worst-case graphs
    llama_context(
            const llama_model & model,
                  llama_context_params params);

    ~llama_context();

    void synchronize();

    const llama_model   & get_model()   const;
    const llama_cparams & get_cparams() const;

    ggml_backend_sched_t get_sched() const;

    uint32_t n_ctx()     const;
    uint32_t n_ctx_seq() const;
    uint32_t n_batch()   const;
    uint32_t n_ubatch()  const;
    uint32_t n_seq_max() const;

    uint32_t n_threads()       const;
    uint32_t n_threads_batch() const;

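// Usage sketch (illustrative, not part of the upstream header): the relationship between
// the size accessors above. A logical batch passed to decode() holds at most n_batch()
// tokens and is split internally into ubatches of at most n_ubatch() tokens. The sketch
// feeds a long prompt in n_batch()-sized chunks via llama_batch_get_one() from the public
// llama.h API; `decode_prompt` is a hypothetical helper name.
#if 0
static int decode_prompt(llama_context & ctx, std::vector<llama_token> & prompt) {
    const size_t n_batch = ctx.n_batch();
    for (size_t i = 0; i < prompt.size(); i += n_batch) {
        const size_t remaining = prompt.size() - i;
        const int32_t n = (int32_t) (remaining < n_batch ? remaining : n_batch);
        llama_batch batch = llama_batch_get_one(prompt.data() + i, n);
        if (int ret = ctx.decode(batch); ret != 0) {
            return ret; // propagate the non-zero decode status
        }
    }
    return 0;
}
#endif
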
    llama_memory_t get_memory() const;

    // return true if the memory was updated
    bool memory_update(bool optimize);

    enum llama_pooling_type pooling_type() const;

    float * get_logits();
    float * get_logits_ith(int32_t i);

    float * get_embeddings();
    float * get_embeddings_ith(int32_t i);
    float * get_embeddings_seq(llama_seq_id seq_id);

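// Usage sketch (illustrative, not part of the upstream header): a greedy pick over the
// logits of one output row. `i` indexes the outputs of the last decode() call and
// `n_vocab` is assumed to come from the model's vocabulary; the helper name is hypothetical.
#if 0
static llama_token greedy_pick(llama_context & ctx, int32_t i, int32_t n_vocab) {
    const float * logits = ctx.get_logits_ith(i);
    llama_token best = 0;
    for (llama_token t = 1; t < n_vocab; ++t) {
        if (logits[t] > logits[best]) {
            best = t;
        }
    }
    return best;
}
#endif
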
    void attach_threadpool(
            ggml_threadpool_t threadpool,
            ggml_threadpool_t threadpool_batch);

    void detach_threadpool();

    void set_n_threads(int32_t n_threads, int32_t n_threads_batch);

    void set_abort_callback(bool (*abort_callback)(void * data), void * abort_callback_data);

    void set_embeddings (bool value);
    void set_causal_attn(bool value);
    void set_warmup(bool value);

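// Usage sketch (illustrative, not part of the upstream header): wiring an abort callback
// so that a long decode() can be interrupted from another thread by flipping a flag.
// The std::atomic flag (<atomic> assumed) and the helper name are hypothetical.
#if 0
static std::atomic<bool> g_should_abort{false};

static void install_abort_flag(llama_context & ctx) {
    ctx.set_abort_callback(
        [](void * data) {
            // returning true requests that the current computation is aborted
            return static_cast<std::atomic<bool> *>(data)->load();
        },
        &g_should_abort);
}
#endif
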
    void set_adapter_lora(
            llama_adapter_lora * adapter,
            float scale);

    bool rm_adapter_lora(
            llama_adapter_lora * adapter);

    void clear_adapter_lora();

    bool apply_adapter_cvec(
            const float * data,
            size_t len,
            int32_t n_embd,
            int32_t il_start,
            int32_t il_end);

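// Usage sketch (illustrative, not part of the upstream header): toggling a LoRA adapter
// on the context. The adapter is assumed to have been loaded elsewhere through the public
// adapter API; a zero scale is handled by removing the adapter instead. Helper name is
// hypothetical.
#if 0
static void set_lora_strength(llama_context & ctx, llama_adapter_lora * adapter, float scale) {
    if (scale == 0.0f) {
        ctx.rm_adapter_lora(adapter);         // drop the adapter entirely
    } else {
        ctx.set_adapter_lora(adapter, scale); // add the adapter or update its scale
    }
}
#endif
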
    // process a single ubatch with a specific graph type
    // if memory_context is provided, it will be applied first to the context's memory
    // ret contains the status of the graph computation
    // returns nullptr only if ret != GGML_STATUS_SUCCESS
    llm_graph_result * process_ubatch(
            const llama_ubatch & ubatch,
            llm_graph_type gtype,
            llama_memory_context_i * mctx,
            ggml_status & ret);

    int encode(const llama_batch & batch_inp);
    int decode(const llama_batch & batch_inp);

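// Usage sketch (illustrative, not part of the upstream header): minimal handling of the
// decode() status, following the convention documented for llama_decode() in the public
// API: 0 = success, > 0 = the batch could not be processed (e.g. no KV-cache slot was
// found), < 0 = fatal error. The helper name is hypothetical.
#if 0
static bool try_decode(llama_context & ctx, const llama_batch & batch) {
    const int ret = ctx.decode(batch);
    if (ret < 0) {
        return false; // fatal: the context may be in an unusable state
    }
    if (ret > 0) {
        return false; // recoverable: the caller may retry with a smaller batch
    }
    return true;      // success: outputs are available via get_logits*()/get_embeddings*()
}
#endif
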
    //
    // state save/load
    //

    size_t state_get_size();
    size_t state_get_data(      uint8_t * dst, size_t size);
    size_t state_set_data(const uint8_t * src, size_t size);

    size_t state_seq_get_size(llama_seq_id seq_id, llama_state_seq_flags flags);
    size_t state_seq_get_data(llama_seq_id seq_id, uint8_t * dst, size_t size, llama_state_seq_flags flags);
    size_t state_seq_set_data(llama_seq_id seq_id, const uint8_t * src, size_t size, llama_state_seq_flags flags);

    bool state_load_file(
            const char * filepath,
            llama_token * tokens_out,
            size_t n_token_capacity,
            size_t * n_token_count_out);

    bool state_save_file(
            const char * filepath,
            const llama_token * tokens,
            size_t n_token_count);

    size_t state_seq_load_file(
            llama_seq_id seq_id,
            const char * filepath,
            llama_token * tokens_out,
            size_t n_token_capacity,
            size_t * n_token_count_out);

    size_t state_seq_save_file(
            llama_seq_id seq_id,
            const char * filepath,
            const llama_token * tokens,
            size_t n_token_count);

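// Usage sketch (illustrative, not part of the upstream header): snapshotting the full
// context state into a byte buffer and restoring it later, using the state_get_size()/
// state_get_data()/state_set_data() triplet above. Helper names are hypothetical.
#if 0
static std::vector<uint8_t> snapshot_state(llama_context & ctx) {
    std::vector<uint8_t> buf(ctx.state_get_size());
    buf.resize(ctx.state_get_data(buf.data(), buf.size())); // shrink to what was written
    return buf;
}

static bool restore_state(llama_context & ctx, const std::vector<uint8_t> & buf) {
    return ctx.state_set_data(buf.data(), buf.size()) == buf.size();
}
#endif
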
    //
    // perf
    //

    llama_perf_context_data perf_get_data() const;
    void perf_reset();

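// Usage sketch (illustrative, not part of the upstream header): deriving a rough
// generation speed from the perf counters, assuming the t_eval_ms / n_eval fields of
// llama_perf_context_data as declared in the public llama.h. Helper name is hypothetical.
#if 0
static double tokens_per_second(const llama_context & ctx) {
    const llama_perf_context_data data = ctx.perf_get_data();
    return data.t_eval_ms > 0.0 ? 1e3 * data.n_eval / data.t_eval_ms : 0.0;
}
#endif
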
    std::map<ggml_backend_buffer_type_t, llama_memory_breakdown_data> memory_breakdown() const;

    //
    // training
    //

    void opt_init(struct llama_model * model, struct llama_opt_params lopt_params);

    // TODO: more flexible combinations of logical/physical batch size and context size
    void opt_epoch(
            ggml_opt_dataset_t dataset,
            ggml_opt_result_t result_train,
            ggml_opt_result_t result_eval,
            int64_t idata_split,
            ggml_opt_epoch_callback callback_train,
            ggml_opt_epoch_callback callback_eval);

    void opt_epoch_iter(
            ggml_opt_dataset_t dataset,
            ggml_opt_result_t result,
            const std::vector<llama_token> & tokens,
            const std::vector<llama_token> & labels_sparse,
            llama_batch & batch,
            ggml_opt_epoch_callback callback,
            bool train,
            int64_t idata_in_loop,
            int64_t ndata_in_loop,
            int64_t t_loop_start);

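// Usage sketch (illustrative, not part of the upstream header): driving one training
// epoch with the hooks above, with opt_init() called before opt_epoch(). The dataset,
// result objects, callbacks and llama_opt_params are assumed to have been prepared with
// the ggml-opt / public llama API; idata_split separates the training portion of the
// dataset from the evaluation portion. Helper name is hypothetical.
#if 0
static void run_epoch(
        llama_context    & ctx,
        llama_model      & model,
        llama_opt_params   lopt_params,
        ggml_opt_dataset_t dataset,
        ggml_opt_result_t  result_train,
        ggml_opt_result_t  result_eval,
        int64_t            idata_split,
        ggml_opt_epoch_callback callback_train,
        ggml_opt_epoch_callback callback_eval) {
    ctx.opt_init(&model, lopt_params); // attach the optimizer state to the model
    ctx.opt_epoch(dataset, result_train, result_eval, idata_split, callback_train, callback_eval);
}
#endif
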
private:
    //
    // output
    //

    // Make sure enough space is available for outputs.
    // Returns max number of outputs for which space was reserved.
    uint32_t output_reserve(int32_t n_outputs);

    void output_reorder();

    //
    // graph
    //

public:
    uint32_t graph_max_nodes(uint32_t n_tokens) const;

    // can reuse the llm_graph_result instance of the context (for example to update a memory module)
    llm_graph_result * get_gf_res_reserve() const;

    // returns the result of ggml_backend_sched_graph_compute_async execution
    ggml_status graph_compute(ggml_cgraph * gf, bool batched);

    // reserve a graph with a dummy ubatch of the specified size
    ggml_cgraph * graph_reserve(
            uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs,
            const llama_memory_context_i * mctx,
            bool split_only = false, size_t * sizes = nullptr);

private:
    llm_graph_params graph_params(
            llm_graph_result * res,
            const llama_ubatch & ubatch,
            const llama_memory_context_i * mctx,
            llm_graph_type gtype) const;

    llm_graph_cb graph_get_cb() const;

    // TODO: read/write lora adapters and cvec
    size_t state_write_data(llama_io_write_i & io);
    size_t state_read_data (llama_io_read_i  & io);

    size_t state_seq_write_data(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags);
    size_t state_seq_read_data (llama_io_read_i  & io, llama_seq_id seq_id, llama_state_seq_flags flags);

    //
    // members
    //

    const llama_model & model;

    llama_cparams       cparams;
    llama_adapter_cvec  cvec;
    llama_adapter_loras loras;

    llama_cross cross; // TODO: tmp for handling cross-attention - need something better probably

    std::unique_ptr<llama_memory_i> memory;

    // decode output (2-dimensional array: [n_outputs][n_vocab])
    size_t  logits_size = 0; // capacity (of floats) for logits
    float * logits      = nullptr;

    // embeddings output (2-dimensional array: [n_outputs][n_embd])
    // populated only when pooling_type == LLAMA_POOLING_TYPE_NONE
    size_t  embd_size = 0; // capacity (of floats) for embeddings
    float * embd      = nullptr;

    // sequence embeddings output (map of [n_embd] vectors)
    // populated only when pooling_type != LLAMA_POOLING_TYPE_NONE
    std::map<llama_seq_id, std::vector<float>> embd_seq;

    // reuse the batch_allocr to avoid unnecessary memory allocations
    std::unique_ptr<llama_batch_allocr> balloc;

    uint32_t n_outputs = 0; // number of actually-used outputs in the current ubatch or last logical batch

    std::vector<int32_t> output_ids; // map batch token positions to ids of the logits and embd buffers

    struct swap_info {
        uint32_t i0;
        uint32_t i1;
    };

    std::vector<swap_info> output_swaps;

    ggml_backend_sched_ptr sched;

    ggml_backend_t backend_cpu = nullptr;
    std::vector<ggml_backend_ptr> backends;

    // training
    ggml_opt_context_t opt_ctx = nullptr;

    ggml_threadpool_t threadpool       = nullptr;
    ggml_threadpool_t threadpool_batch = nullptr;

    ggml_abort_callback abort_callback      = nullptr;
    void *              abort_callback_data = nullptr;

    std::vector<std::pair<ggml_backend_t, ggml_backend_set_n_threads_t>> set_n_threads_fns;

    // pointers and buffer types used for the compute buffer of each backend
    std::vector<ggml_backend_t>             backend_ptrs;
    std::vector<ggml_backend_buffer_type_t> backend_buft;
    std::vector<size_t>                     backend_buf_exp_size; // expected buffer sizes

    llm_graph_result_ptr gf_res_prev;
    llm_graph_result_ptr gf_res_reserve;

    // host buffer for the model output (logits and embeddings)
    ggml_backend_buffer_ptr buf_output;

    bool has_evaluated_once = false;

    // env: LLAMA_GRAPH_REUSE_DISABLE
    bool graph_reuse_disable = false;

    // perf
    mutable int64_t t_start_us  = 0;
    mutable int64_t t_load_us   = 0;
    mutable int64_t t_p_eval_us = 0;
    mutable int64_t t_eval_us   = 0;

    mutable int64_t t_compute_start_us = 0;
    mutable int64_t n_queued_tokens    = 0;

    mutable int32_t n_p_eval = 0; // number of tokens in eval calls for the prompt (with batch size > 1)
    mutable int32_t n_eval   = 0; // number of eval calls

    mutable int32_t n_reused = 0; // number of times the previous graph was reused
};