/src/llama.cpp/src/models/mamba-base.cpp
Line | Count | Source |
1 | | #include "models.h" |
2 | | |
3 | | #include "llama-memory-recurrent.h" |
4 | | |
5 | 0 | llm_build_mamba_base::llm_build_mamba_base(const llm_graph_params & params) : llm_graph_context(params) {} |
6 | | |
// Build one Mamba (v1) selective-state-space layer for layer `il`:
//   in-projection -> causal 1D conv -> selective SSM scan -> gated out-projection.
// `cur` is the layer input of shape {n_embd, n_tokens}; the returned tensor has
// the same shape. The recurrent conv and ssm states are read from, and written
// back to, the recurrent memory context carried by `inp`.
ggml_tensor * llm_build_mamba_base::build_mamba_layer(llm_graph_input_rs * inp,
                                                      ggml_tensor * cur,
                                                      const llama_model & model,
                                                      const llama_ubatch & ubatch,
                                                      int il) {
    const auto * mctx_cur = inp->mctx;

    // slot in the recurrent-state cache where this ubatch's states live
    const auto kv_head = mctx_cur->get_head();

    const auto & layer = model.layers[il];

    const int64_t d_conv  = hparams.ssm_d_conv;
    const int64_t d_inner = hparams.ssm_d_inner;
    const int64_t d_state = hparams.ssm_d_state;
    const int64_t dt_rank = hparams.ssm_dt_rank;
    // Mamba-1 is expressed here as the degenerate multi-head case:
    // one "head" per inner channel, each head of dimension 1
    const int64_t n_head   = d_inner;
    const int64_t head_dim = 1;
    const int64_t n_seqs   = ubatch.n_seqs;
    // Some variants of Mamba arch (e.g. FalconMamba) apply RMS norm on the B and Dt projections
    const bool ssm_dt_b_c_rms = hparams.ssm_dt_b_c_rms;

    const int64_t n_seq_tokens = ubatch.n_seq_tokens;

    GGML_ASSERT(n_seqs != 0);
    GGML_ASSERT(ubatch.equal_seqs());
    GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs);
    GGML_ASSERT(d_inner % n_head == 0);

    // per-layer recurrent states: r = conv shift state, s = ssm hidden state
    ggml_tensor * conv_states_all = mctx_cur->get_r_l(il);
    ggml_tensor * ssm_states_all  = mctx_cur->get_s_l(il);

    // load the previous conv state rows for the sequences of this ubatch
    ggml_tensor * conv = build_rs(inp, conv_states_all, hparams.n_embd_r(), n_seqs);
    conv = ggml_reshape_3d(ctx0, conv, d_conv - 1, d_inner, n_seqs);

    // {n_embd, n_tokens} => {n_embd, n_seq_tokens, n_seqs}
    cur = ggml_reshape_3d(ctx0, cur, cur->ne[0], n_seq_tokens, n_seqs);

    // {n_embd, 2*d_inner} @ {n_embd, n_seq_tokens, n_seqs} => {2*d_inner, n_seq_tokens, n_seqs}
    ggml_tensor * xz = build_lora_mm(layer.ssm_in, cur, layer.ssm_in_s);
    // split the above in two: x feeds the conv/SSM path, z is the gate
    // => {d_inner, n_seq_tokens, n_seqs}
    ggml_tensor * x = ggml_view_3d(ctx0, xz, d_inner, xz->ne[1], xz->ne[2], xz->nb[1], xz->nb[2], 0);
    ggml_tensor * z =
        ggml_view_3d(ctx0, xz, d_inner, xz->ne[1], xz->ne[2], xz->nb[1], xz->nb[2], d_inner * ggml_element_size(xz));

    // conv
    {
        // prepend the cached conv state to the new tokens
        // => {d_conv - 1 + n_seq_tokens, d_inner, n_seqs}
        ggml_tensor * conv_x = ggml_concat(ctx0, conv, ggml_transpose(ctx0, x), 0);

        // copy last (d_conv - 1) columns back into the state cache
        ggml_tensor * last_conv = ggml_view_3d(ctx0, conv_x, d_conv - 1, d_inner, n_seqs, conv_x->nb[1], conv_x->nb[2],
                                               n_seq_tokens * (conv_x->nb[0]));

        // the write-back must be expanded into the graph explicitly because
        // nothing downstream in this layer consumes its result
        ggml_build_forward_expand(
            gf, ggml_cpy(ctx0, last_conv,
                         ggml_view_1d(ctx0, conv_states_all, (d_conv - 1) * (d_inner) * (n_seqs),
                                      kv_head * (d_conv - 1) * (d_inner) * ggml_element_size(conv_states_all))));

        // 1D convolution
        // The equivalent is to make a self-overlapping view of conv_x
        // over d_conv columns at each stride in the 3rd dimension,
        // then element-wise multiply that with the conv1d weight,
        // then sum the elements of each row,
        // (the last two steps are a dot product over rows (also doable with mul_mat))
        // then permute away the ne[0] dimension,
        // and then you're left with the resulting x tensor.
        // For simultaneous sequences, all sequences need to have the same length.
        x = ggml_ssm_conv(ctx0, conv_x, layer.ssm_conv1d);

        // bias
        x = ggml_add(ctx0, x, layer.ssm_conv1d_b);

        x = ggml_silu(ctx0, x);
    }

    // ssm
    {
        // {d_inner, dt_rank + 2*d_state} @ {d_inner, n_seq_tokens, n_seqs} => {dt_rank + 2*d_state, n_seq_tokens, n_seqs}
        ggml_tensor * x_db = build_lora_mm(layer.ssm_x, x);
        // split into the input-dependent dt, B and C (n_group is 1 for Mamba-1)
        ggml_tensor * dt = ggml_view_3d(ctx0, x_db, dt_rank, n_seq_tokens, n_seqs, x_db->nb[1], x_db->nb[2], 0);
        ggml_tensor * B =
            ggml_view_4d(ctx0, x_db, d_state, /* n_group */ 1, n_seq_tokens, n_seqs, d_state * x_db->nb[0], x_db->nb[1],
                         x_db->nb[2], ggml_element_size(x_db) * dt_rank);
        ggml_tensor * C =
            ggml_view_4d(ctx0, x_db, d_state, /* n_group */ 1, n_seq_tokens, n_seqs, d_state * x_db->nb[0], x_db->nb[1],
                         x_db->nb[2], ggml_element_size(x_db) * (dt_rank + d_state));

        // Some Mamba variants (e.g. FalconMamba, Jamba) apply RMS norm in B, C & Dt layers
        // (build_norm with a NULL weight tensor performs an unscaled RMS norm)
        if (ssm_dt_b_c_rms || (layer.ssm_dt_norm && layer.ssm_b_norm && layer.ssm_c_norm)) {
            dt = build_norm(dt, layer.ssm_dt_norm, NULL, LLM_NORM_RMS, il);
            B = build_norm(B, layer.ssm_b_norm, NULL, LLM_NORM_RMS, il);
            C = build_norm(C, layer.ssm_c_norm, NULL, LLM_NORM_RMS, il);
        }

        // {dt_rank, d_inner} @ {dt_rank, n_seq_tokens, n_seqs} => {d_inner, n_seq_tokens, n_seqs}
        dt = build_lora_mm(layer.ssm_dt, dt);
        dt = ggml_add(ctx0, dt, layer.ssm_dt_b);

        // keep the pre-reshape x around in `cur`: it is the residual used by the
        // D skip-connection further below
        cur = x;
        x = ggml_reshape_4d(ctx0, x, head_dim, n_head, n_seq_tokens, n_seqs);

        ggml_tensor * A = layer.ssm_a;

        // use the states and the indices provided by build_recurrent_state
        // (this is necessary in order to properly use the states before they are overwritten,
        // while avoiding to make unnecessary copies of the states)
        auto get_ssm_rows = [&](ggml_context * ctx, ggml_tensor * states, ggml_tensor * ids) {
            ggml_tensor * ssm = ggml_reshape_4d(ctx, states, d_state, head_dim, n_head, mctx_cur->get_size());

            // Custom operator to optimize the parallel associative scan
            // as described in the Annex D of the Mamba paper.
            // => {d_inner, n_seq_tokens, n_seqs} and {d_state, d_inner, n_seqs}
            return ggml_ssm_scan(ctx, ssm, x, dt, A, B, C, ids);
        };

        ggml_tensor * y_ssm = build_rs(inp, ssm_states_all, hparams.n_embd_s(), ubatch.n_seqs, get_ssm_rows);

        // store last states
        // (the final ssm states are laid out in y_ssm after the per-token outputs,
        // hence the x->nb[3] * x->ne[3] byte offset)
        ggml_build_forward_expand(
            gf, ggml_cpy(ctx0, ggml_view_1d(ctx0, y_ssm, d_state * d_inner * n_seqs, x->nb[3] * x->ne[3]),
                         ggml_view_1d(ctx0, ssm_states_all, d_state * d_inner * n_seqs,
                                      kv_head * d_state * d_inner * ggml_element_size(ssm_states_all))));

        ggml_tensor * y = ggml_view_3d(ctx0, y_ssm, d_inner, n_seq_tokens, n_seqs, x->nb[2], x->nb[3], 0);

        // TODO: skip computing output earlier for unused tokens

        // D skip-connection on the conv output, then SiLU(z) gating
        y = ggml_add(ctx0, y, ggml_mul(ctx0, cur, layer.ssm_d));
        y = ggml_swiglu_split(ctx0, ggml_cont(ctx0, z), y);

        // {d_inner, n_embd} @ {d_inner, n_seq_tokens, n_seqs} => {n_embd, n_seq_tokens, n_seqs}
        cur = build_lora_mm(layer.ssm_out, y, layer.ssm_out_s);
    }

    // {n_embd, n_seq_tokens, n_seqs} => {n_embd, n_tokens}
    cur = ggml_reshape_2d(ctx0, cur, cur->ne[0], n_seq_tokens * n_seqs);

    return cur;
}
148 | | |
// Build one Mamba-2 layer for layer `il`:
//   fused in-projection (z, xBC, dt) -> causal 1D conv over xBC ->
//   multi-head SSM scan -> z-gating (+ optional grouped RMS norm) -> out-projection.
// `cur` is the layer input of shape {n_embd, n_tokens}; the returned tensor has
// the same shape. The recurrent conv and ssm states are read from, and written
// back to, the recurrent memory context carried by `inp`.
ggml_tensor * llm_build_mamba_base::build_mamba2_layer(llm_graph_input_rs * inp,
                                                       ggml_tensor * cur,
                                                       const llama_model & model,
                                                       const llama_ubatch & ubatch,
                                                       int il) const {
    const auto * mctx_cur = inp->mctx;

    // slot in the recurrent-state cache where this ubatch's states live
    const auto kv_head = mctx_cur->get_head();

    const int64_t d_conv  = hparams.ssm_d_conv;
    const int64_t d_inner = hparams.ssm_d_inner;
    const int64_t d_state = hparams.ssm_d_state;
    // for Mamba-2, ssm_dt_rank holds the number of SSM heads
    const int64_t n_head   = hparams.ssm_dt_rank;
    const int64_t head_dim = d_inner / n_head;
    const int64_t n_group  = hparams.ssm_n_group;
    const int64_t n_seqs   = ubatch.n_seqs;

    const int64_t n_seq_tokens = ubatch.n_seq_tokens;

    GGML_ASSERT(n_seqs != 0);
    GGML_ASSERT(ubatch.equal_seqs());
    GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs);
    GGML_ASSERT(d_inner % n_head == 0);
    GGML_ASSERT(d_inner % d_state == 0);
    GGML_ASSERT(d_inner % n_group == 0);

    // per-layer recurrent states: r = conv shift state, s = ssm hidden state
    ggml_tensor * conv_states_all = mctx_cur->get_r_l(il);
    ggml_tensor * ssm_states_all  = mctx_cur->get_s_l(il);

    // load the previous conv state rows for the sequences of this ubatch
    // (the conv acts on x, B and C together, hence d_inner + 2*n_group*d_state rows)
    ggml_tensor * conv = build_rs(inp, conv_states_all, hparams.n_embd_r(), n_seqs);
    conv = ggml_reshape_3d(ctx0, conv, d_conv - 1, d_inner + 2 * n_group * d_state, n_seqs);

    // {n_embd, n_tokens} => {n_embd, n_seq_tokens, n_seqs}
    cur = ggml_reshape_3d(ctx0, cur, cur->ne[0], n_seq_tokens, n_seqs);

    // d_in_proj = 2 * self.d_inner + 2 * self.ngroups * self.d_state + self.nheads

    // {n_embd, d_in_proj} @ {n_embd, n_seq_tokens, n_seqs} => {d_in_proj, n_seq_tokens, n_seqs}
    ggml_tensor * zxBCdt = build_lora_mm(model.layers[il].ssm_in, cur, model.layers[il].ssm_in_s);

    // split the above in three: the gate z, the conv/SSM input xBC, and dt
    ggml_tensor * z = ggml_view_4d(ctx0, zxBCdt, head_dim, n_head, n_seq_tokens, n_seqs, head_dim * zxBCdt->nb[0],
                                   zxBCdt->nb[1], zxBCdt->nb[2], 0);
    ggml_tensor * xBC = ggml_view_3d(ctx0, zxBCdt, d_inner + 2 * n_group * d_state, n_seq_tokens, n_seqs, zxBCdt->nb[1],
                                     zxBCdt->nb[2], d_inner * ggml_element_size(zxBCdt));
    ggml_tensor * dt = ggml_view_3d(ctx0, zxBCdt, n_head, n_seq_tokens, n_seqs, zxBCdt->nb[1], zxBCdt->nb[2],
                                    (2 * d_inner + 2 * n_group * d_state) * ggml_element_size(zxBCdt));

    // conv
    {
        // prepend the cached conv state to the new tokens
        // => {d_conv - 1 + n_seq_tokens, d_inner + 2*n_group*d_state, n_seqs}
        ggml_tensor * conv_x = ggml_concat(ctx0, conv, ggml_transpose(ctx0, xBC), 0);

        // copy last (d_conv - 1) columns back into the state cache
        ggml_tensor * last_conv = ggml_view_3d(ctx0, conv_x, d_conv - 1, d_inner + 2 * n_group * d_state, n_seqs,
                                               conv_x->nb[1], conv_x->nb[2], n_seq_tokens * (conv_x->nb[0]));

        // the write-back must be expanded into the graph explicitly because
        // nothing downstream in this layer consumes its result
        ggml_build_forward_expand(gf, ggml_cpy(ctx0, last_conv,
                                               ggml_view_1d(ctx0, conv_states_all,
                                                            (d_conv - 1) * (d_inner + 2 * n_group * d_state) * (n_seqs),
                                                            kv_head * (d_conv - 1) * (d_inner + 2 * n_group * d_state) *
                                                                ggml_element_size(conv_states_all))));

        // 1D convolution
        // The equivalent is to make a self-overlapping view of conv_x
        // over d_conv columns at each stride in the 3rd dimension,
        // then element-wise multiply that with the conv1d weight,
        // then sum the elements of each row,
        // (the last two steps are a dot product over rows (also doable with mul_mat))
        // then permute away the ne[0] dimension,
        // and then you're left with the resulting x tensor.
        // For simultaneous sequences, all sequences need to have the same length.
        xBC = ggml_ssm_conv(ctx0, conv_x, model.layers[il].ssm_conv1d);

        // bias
        xBC = ggml_add(ctx0, xBC, model.layers[il].ssm_conv1d_b);

        xBC = ggml_silu(ctx0, xBC);
    }

    // ssm
    {
        // These correspond to V K Q in SSM/attention duality
        ggml_tensor * x = ggml_view_4d(ctx0, xBC, head_dim, n_head, n_seq_tokens, n_seqs, head_dim * xBC->nb[0],
                                       xBC->nb[1], xBC->nb[2], 0);
        ggml_tensor * B = ggml_view_4d(ctx0, xBC, d_state, n_group, n_seq_tokens, n_seqs, d_state * xBC->nb[0],
                                       xBC->nb[1], xBC->nb[2], d_inner * ggml_element_size(xBC));
        ggml_tensor * C = ggml_view_4d(ctx0, xBC, d_state, n_group, n_seq_tokens, n_seqs, d_state * xBC->nb[0],
                                       xBC->nb[1], xBC->nb[2], (d_inner + n_group * d_state) * ggml_element_size(xBC));

        // per-head time step with bias
        // {n_head, n_seq_tokens, n_seqs}
        dt = ggml_add(ctx0, ggml_cont(ctx0, dt), model.layers[il].ssm_dt_b);

        ggml_tensor * A = model.layers[il].ssm_a;

        // use the states and the indices provided by build_recurrent_state
        // (this is necessary in order to properly use the states before they are overwritten,
        // while avoiding to make unnecessary copies of the states)
        auto get_ssm_rows = [&](ggml_context * ctx, ggml_tensor * states, ggml_tensor * ids) {
            ggml_tensor * ssm = ggml_reshape_4d(ctx, states, d_state, head_dim, n_head, mctx_cur->get_size());

            // TODO: use semistructured matrices to implement state-space duality
            // => {d_inner, n_seq_tokens, n_seqs} and {d_state, d_inner, n_seqs}
            return ggml_ssm_scan(ctx, ssm, x, dt, A, B, C, ids);
        };

        ggml_tensor * y_ssm = build_rs(inp, ssm_states_all, hparams.n_embd_s(), ubatch.n_seqs, get_ssm_rows);

        // store last states
        // (the final ssm states are laid out in y_ssm after the per-token outputs,
        // hence the ggml_nelements(x) * x->nb[0] byte offset)
        ggml_build_forward_expand(
            gf, ggml_cpy(ctx0, ggml_view_1d(ctx0, y_ssm, d_state * d_inner * n_seqs, ggml_nelements(x) * x->nb[0]),
                         ggml_view_1d(ctx0, ssm_states_all, d_state * d_inner * n_seqs,
                                      kv_head * d_state * d_inner * ggml_element_size(ssm_states_all))));

        ggml_tensor * y = ggml_view_4d(ctx0, y_ssm, head_dim, n_head, n_seq_tokens, n_seqs, x->nb[1], n_head * x->nb[1],
                                       n_seq_tokens * n_head * x->nb[1], 0);

        // TODO: skip computing output earlier for unused tokens

        // D skip-connection, then SiLU(z) gating
        y = ggml_add(ctx0, y, ggml_mul(ctx0, x, model.layers[il].ssm_d));
        cb(y, "mamba2_y_add_d", il);
        y = ggml_swiglu_split(ctx0, ggml_cont(ctx0, z), y);

        // grouped RMS norm
        if (model.layers[il].ssm_norm) {
            y = ggml_reshape_4d(ctx0, y, d_inner / n_group, n_group, n_seq_tokens, n_seqs);
            y = build_norm(y, model.layers[il].ssm_norm, NULL, LLM_NORM_RMS, il);
        }

        y = ggml_reshape_3d(ctx0, y, d_inner, n_seq_tokens, n_seqs);

        // {d_inner, n_embd} @ {d_inner, n_seq_tokens, n_seqs} => {n_embd, n_seq_tokens, n_seqs}
        cur = build_lora_mm(model.layers[il].ssm_out, y, model.layers[il].ssm_out_s);
    }

    // {n_embd, n_seq_tokens, n_seqs} => {n_embd, n_tokens}
    cur = ggml_reshape_2d(ctx0, cur, cur->ne[0], n_seq_tokens * n_seqs);
    cb(cur, "mamba_out", il);

    return cur;
}