/src/llama.cpp/src/llama-memory.h
#pragma once

#include "llama.h"

#include <map>
#include <memory>
#include <functional>

struct llama_ubatch;

class llama_batch_allocr;

class llama_io_write_i;
class llama_io_read_i;

struct llama_memory_params {
    // kv cache
    ggml_type type_k;
    ggml_type type_v;

    // use full-size SWA cache
    bool swa_full;
};
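
// illustrative sketch (not part of the upstream header): filling in the
// parameters - GGML_TYPE_F16 for the K/V cache types is just an example choice:
//
//     llama_memory_params params = {
//         /*.type_k   =*/ GGML_TYPE_F16,
//         /*.type_v   =*/ GGML_TYPE_F16,
//         /*.swa_full =*/ false,
//     };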

enum llama_memory_status {
    LLAMA_MEMORY_STATUS_SUCCESS = 0,
    LLAMA_MEMORY_STATUS_NO_UPDATE,
    LLAMA_MEMORY_STATUS_FAILED_PREPARE,
    LLAMA_MEMORY_STATUS_FAILED_COMPUTE,
};

// helper function for combining the status of two memory contexts
// useful for implementing hybrid memory types (e.g. iSWA)
llama_memory_status llama_memory_status_combine(llama_memory_status s0, llama_memory_status s1);

// helper function for checking if a memory status indicates a failure
bool llama_memory_status_is_fail(llama_memory_status status);
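
// illustrative sketch (not part of the upstream header) of how the two helpers
// above could behave - a failure is one of the FAILED_* statuses, and combining
// propagates failures while keeping SUCCESS if either context has work to do:
//
//     bool llama_memory_status_is_fail(llama_memory_status status) {
//         return status == LLAMA_MEMORY_STATUS_FAILED_PREPARE ||
//                status == LLAMA_MEMORY_STATUS_FAILED_COMPUTE;
//     }
//
//     llama_memory_status llama_memory_status_combine(llama_memory_status s0, llama_memory_status s1) {
//         if (llama_memory_status_is_fail(s0)) { return s0; }
//         if (llama_memory_status_is_fail(s1)) { return s1; }
//         return (s0 == LLAMA_MEMORY_STATUS_SUCCESS || s1 == LLAMA_MEMORY_STATUS_SUCCESS) ?
//             LLAMA_MEMORY_STATUS_SUCCESS : LLAMA_MEMORY_STATUS_NO_UPDATE;
//     }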

// the interface for managing the memory context during batch processing
// this interface is implemented per memory type. see:
// - llama_kv_cache_context
// - llama_kv_cache_iswa_context
// ...
//
// the only method that should mutate the memory and the memory context is llama_memory_context_i::apply()
struct llama_memory_context_i {
    virtual ~llama_memory_context_i() = default;

    // consume the current ubatch from the context and proceed to the next one
    // return false if we are done
    virtual bool next() = 0;

    // apply the memory state for the current ubatch to the memory object
    // return false on failure
    virtual bool apply() = 0;

    // get the current ubatch
    virtual const llama_ubatch & get_ubatch() const = 0;

    // get the status of the memory context - used for error handling and checking if any updates would be applied
    virtual llama_memory_status get_status() const = 0;
};

using llama_memory_context_ptr = std::unique_ptr<llama_memory_context_i>;
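
// illustrative sketch (not part of the upstream header): consuming a memory
// context - `mctx` is assumed to come from one of the llama_memory_i::init_*()
// calls declared below, and the graph evaluation step is elided:
//
//     if (llama_memory_status_is_fail(mctx->get_status())) {
//         // the context could not be prepared - handle the error
//     }
//     do {
//         if (!mctx->apply()) {
//             break; // applying the memory state failed
//         }
//         const llama_ubatch & ubatch = mctx->get_ubatch();
//         // ... build and evaluate the compute graph for this ubatch ...
//     } while (mctx->next());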

// general concept of LLM memory
// the KV cache is a type of LLM memory, but there can be other types
struct llama_memory_i {
    // this callback is used to filter out layers that should not be included in the cache
    using layer_filter_cb = std::function<bool(int32_t il)>;

    // this callback is used to specify which layers should reuse memory from other layers
    // return a negative value to indicate that the layer il should not reuse memory
    using layer_reuse_cb = std::function<int32_t(int32_t il)>;
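
    // illustrative sketch (not part of the upstream header): hypothetical
    // callbacks - a filter that keeps only even-indexed layers and a reuse
    // policy where odd layers borrow the memory of the preceding even layer:
    //
    //     llama_memory_i::layer_filter_cb filter = [](int32_t il) {
    //         return il % 2 == 0;
    //     };
    //
    //     llama_memory_i::layer_reuse_cb reuse = [](int32_t il) {
    //         return il % 2 == 1 ? il - 1 : -1;
    //     };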

    virtual ~llama_memory_i() = default;

    // split the input batch into a set of ubatches and verify that they can fit into the cache
    // return a context object containing the ubatches and the memory state required to process them
    // check llama_memory_context_i::get_status() for the result
    virtual llama_memory_context_ptr init_batch(
            llama_batch_allocr & balloc,
            uint32_t n_ubatch,
            bool embd_all) = 0;

    // simulate a full cache, used for allocating worst-case compute buffers
    virtual llama_memory_context_ptr init_full() = 0;

    // prepare for any pending memory updates, such as shifts, copies, etc.
    // status == LLAMA_MEMORY_STATUS_NO_UPDATE if there is nothing to update
    virtual llama_memory_context_ptr init_update(llama_context * lctx, bool optimize) = 0;
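
    // illustrative sketch (not part of the upstream header): applying pending
    // updates - `mem` is an assumed llama_memory_i pointer, `lctx` the current context:
    //
    //     llama_memory_context_ptr uctx = mem->init_update(lctx, /*optimize=*/false);
    //     switch (uctx->get_status()) {
    //         case LLAMA_MEMORY_STATUS_SUCCESS:   uctx->apply(); break; // perform the shifts/copies
    //         case LLAMA_MEMORY_STATUS_NO_UPDATE: break;                // nothing to do
    //         default:                            break;                // a FAILED_* status - handle the error
    //     }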

    // getters
    virtual bool get_can_shift() const = 0;

    //
    // ops
    //

    // if data == true, the data buffers will also be cleared together with the metadata
    virtual void clear(bool data) = 0;

    virtual bool seq_rm  (llama_seq_id seq_id, llama_pos p0, llama_pos p1) = 0;
    virtual void seq_cp  (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) = 0;
    virtual void seq_keep(llama_seq_id seq_id) = 0;
    virtual void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) = 0;
    virtual void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) = 0;

    virtual llama_pos seq_pos_min(llama_seq_id seq_id) const = 0;
    virtual llama_pos seq_pos_max(llama_seq_id seq_id) const = 0;
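
    // illustrative sketch (not part of the upstream header): typical sequence
    // manipulation - the positions are made up, and p1 == -1 is assumed to
    // mean "until the end of the sequence":
    //
    //     mem->seq_cp(0, 1, 0, 32);     // sequence 1 shares positions [0, 32) with sequence 0
    //     mem->seq_rm(1, 32, 64);       // remove positions [32, 64) from sequence 1
    //     mem->seq_add(1, 64, -1, -32); // shift the remaining positions down by 32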

    virtual std::map<ggml_backend_buffer_type_t, size_t> memory_breakdown() const = 0;
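
    // illustrative sketch (not part of the upstream header): reporting memory
    // usage per backend buffer type, assuming ggml's ggml_backend_buft_name():
    //
    //     for (const auto & [buft, size] : mem->memory_breakdown()) {
    //         printf("%s: %zu bytes\n", ggml_backend_buft_name(buft), size);
    //     }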

    //
    // state write/read
    //

    virtual void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) const = 0;
    virtual void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) = 0;
};

using llama_memory_ptr = std::unique_ptr<llama_memory_i>;