/src/llama.cpp/src/llama-memory.h
Line | Count | Source |
1 | | #pragma once |
2 | | |
3 | | #include "llama.h" |
4 | | #include "llama-graph.h" |
5 | | |
6 | | #include <map> |
7 | | #include <memory> |
8 | | #include <functional> |
9 | | |
10 | | struct llama_ubatch; |
11 | | |
12 | | class llama_batch_allocr; |
13 | | |
14 | | class llama_io_write_i; |
15 | | class llama_io_read_i; |
16 | | |
17 | | struct llama_memory_params { |
18 | | // kv cache |
19 | | ggml_type type_k; |
20 | | ggml_type type_v; |
21 | | |
22 | | // use full-size SWA cache |
23 | | bool swa_full; |
24 | | |
25 | | llama_context_type ctx_type; |
26 | | |
27 | | llama_memory_t mem_other; |
28 | | }; |
29 | | |
30 | | enum llama_memory_status { |
31 | | LLAMA_MEMORY_STATUS_SUCCESS = 0, |
32 | | LLAMA_MEMORY_STATUS_NO_UPDATE, |
33 | | LLAMA_MEMORY_STATUS_FAILED_PREPARE, |
34 | | LLAMA_MEMORY_STATUS_FAILED_COMPUTE, |
35 | | }; |
36 | | |
37 | | // helper function for combining the status of two memory contexts |
38 | | // useful for implementing hybrid memory types (e.g. iSWA) |
39 | | llama_memory_status llama_memory_status_combine(llama_memory_status s0, llama_memory_status s1); |
40 | | |
41 | | // helper function for checking if a memory status indicates a failure |
42 | | bool llama_memory_status_is_fail(llama_memory_status status); |
43 | | |
44 | | // the interface for managing the memory context during batch processing |
45 | | // this interface is implemented per memory type. see: |
46 | | // - llama_kv_cache_context |
47 | | // - llama_kv_cache_iswa_context |
48 | | // ... |
49 | | // |
50 | | // the only method that should mutate the memory and the memory context is llama_memory_context_i::apply() |
51 | | struct llama_memory_context_i { |
52 | 0 | virtual ~llama_memory_context_i() = default; |
53 | | |
54 | | // consume the current ubatch from the context and proceed to the next one |
55 | | // return false if we are done |
56 | | virtual bool next() = 0; |
57 | | |
58 | | // apply the memory state for the current ubatch to the memory object |
59 | | // return false on failure |
60 | | virtual bool apply() = 0; |
61 | | |
62 | | // get the current ubatch |
63 | | virtual const llama_ubatch & get_ubatch() const = 0; |
64 | | |
65 | | // get the status of the memory context - used for error handling and checking if any updates would be applied |
66 | | virtual llama_memory_status get_status() const = 0; |
67 | | }; |
68 | | |
69 | | using llama_memory_context_ptr = std::unique_ptr<llama_memory_context_i>; |
70 | | |
71 | | // general concept of LLM memory |
72 | | // the KV cache is a type of LLM memory, but there can be other types |
73 | | struct llama_memory_i { |
74 | | // this callback is used to filter out layers that should not be included in the cache |
75 | | using layer_filter_cb = std::function<bool(int32_t il)>; |
76 | | |
77 | | // this callback is used to specify which layers should reuse memory from other layers |
78 | | // return a negative value to indicate that layer il should not reuse memory |
79 | | using layer_reuse_cb = std::function<int32_t(int32_t il)>; |
80 | | |
81 | | using layer_share_cb = std::function<int32_t(int32_t il)>; |
82 | | |
83 | 0 | virtual ~llama_memory_i() = default; |
84 | | |
85 | | // split the input batch into a set of ubatches and verify that they can fit into the cache |
86 | | // return a context object containing the ubatches and memory state required to process them |
87 | | // check llama_memory_context_i::get_status() on the returned context for the result |
88 | | virtual llama_memory_context_ptr init_batch( |
89 | | llama_batch_allocr & balloc, |
90 | | uint32_t n_ubatch, |
91 | | bool embd_all) = 0; |
92 | | |
93 | | // simulate full cache, used for allocating worst-case compute buffers |
94 | | virtual llama_memory_context_ptr init_full() = 0; |
95 | | |
96 | | // prepare for any pending memory updates, such as shifts, copies, etc. |
97 | | // the returned context's get_status() == LLAMA_MEMORY_STATUS_NO_UPDATE if there is nothing to update |
98 | | virtual llama_memory_context_ptr init_update(llama_context * lctx, bool optimize) = 0; |
99 | | |
100 | | // getters |
101 | | virtual bool get_can_shift() const = 0; |
102 | | |
103 | | // |
104 | | // ops |
105 | | // |
106 | | |
107 | | // if data == true, the data buffers will also be cleared together with the metadata |
108 | | virtual void clear(bool data) = 0; |
109 | | |
110 | | virtual bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) = 0; |
111 | | virtual void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) = 0; |
112 | | virtual void seq_keep(llama_seq_id seq_id) = 0; |
113 | | virtual void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) = 0; |
114 | | virtual void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) = 0; |
115 | | |
116 | | virtual llama_pos seq_pos_min(llama_seq_id seq_id) const = 0; |
117 | | virtual llama_pos seq_pos_max(llama_seq_id seq_id) const = 0; |
118 | | |
119 | | virtual std::map<ggml_backend_buffer_type_t, size_t> memory_breakdown() const = 0; |
120 | | |
121 | | // |
122 | | // state write/read |
123 | | // |
124 | | |
125 | | virtual void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) const = 0; |
126 | | virtual void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) = 0; |
127 | | }; |
128 | | |
129 | | using llama_memory_ptr = std::unique_ptr<llama_memory_i>; |