/src/llama.cpp/src/models/plamo2.cpp
Line | Count | Source |
1 | | #include "models.h" |
2 | | |
3 | | #include "llama-memory-recurrent.h" |
4 | | |
// Builds the full forward graph for PLaMo-2: a hybrid model that interleaves
// Mamba (recurrent/SSM) layers and attention layers, selected per layer via
// hparams.is_recurrent(il). Each layer is a sandwich:
//   pre-norm -> mixer (mamba or attn) -> post-norm -> residual
//   -> pre-ffn norm -> SwiGLU FFN -> post-ffn norm -> residual
// The final norm output is stored in res->t_embd and the lm_head output in
// res->t_logits.
llm_build_plamo2::llm_build_plamo2(const llama_model & model, const llm_graph_params & params) :
    llm_build_mamba_base(params) {
    ggml_tensor * cur;
    ggml_tensor * inpL;

    // {n_embd, n_tokens}
    inpL = build_inp_embd(model.tok_embd);
    cb(inpL, "embedding_output", -1);

    // positions are only consumed by the attention layers (RoPE)
    ggml_tensor * inp_pos = build_inp_pos();

    // hybrid memory input: carries both the recurrent state (for mamba layers)
    // and the KV cache (for attention layers)
    auto * inp_hybrid = build_inp_mem_hybrid();

    ggml_tensor * inp_out_ids = build_inp_out_ids();

    for (int il = 0; il < n_layer; ++il) {
        ggml_tensor * residual = inpL;

        // ggml_graph_add_node(gf, model.layers[il].attn_norm);
        // cb(model.layers[il].attn_norm, "attn_norm", il);

        // pre_mixer_norm
        cur = build_norm(inpL, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il);

        // check if this layer is Mamba or Attention
        const bool is_mamba_layer = hparams.is_recurrent(il);

        if (is_mamba_layer) {
            // PLaMo-2 Mamba layer
            cur = build_plamo2_mamba_layer(inp_hybrid->get_recr(), cur, model, ubatch, il);
        } else {
            // PLaMo-2 Attention layer
            cur = build_plamo2_attn_layer(inp_hybrid->get_attn(), inp_pos, cur, model, il);
        }

        // post_mixer_norm
        cur = build_norm(cur, model.layers[il].attn_post_norm, NULL, LLM_NORM_RMS, il);
        cb(cur, "attn_post_norm", il);

        // residual connection
        cur = ggml_add(ctx0, cur, residual);
        cb(cur, "attn_residual", il);
        residual = cur;

        // pre-ffn norm
        cur = build_norm(cur, model.layers[il].ffn_norm, NULL, LLM_NORM_RMS, il);
        cb(cur, "ffn_pre_norm", il);

        // feed-forward network (SwiGLU; gate and up are fused in ffn_up)
        cur = build_ffn(cur,
                model.layers[il].ffn_up, NULL, NULL,
                NULL, NULL, NULL,
                model.layers[il].ffn_down, NULL, NULL,
                NULL, LLM_FFN_SWIGLU, LLM_FFN_SEQ, il);
        cb(cur, "ffn_out", il);

        // post ffn norm
        cur = build_norm(cur, model.layers[il].ffn_post_norm, NULL, LLM_NORM_RMS, il);
        cb(cur, "ffn_post_norm", il);

        // on the last layer, keep only the rows whose logits are requested;
        // the residual must be filtered identically so the add below lines up
        if (il == n_layer - 1 && inp_out_ids) {
            cur      = ggml_get_rows(ctx0, cur, inp_out_ids);
            residual = ggml_get_rows(ctx0, residual, inp_out_ids);
        }

        // residual connection
        cur = ggml_add(ctx0, cur, residual);
        cb(cur, "ffn_residual", il);

        // input for next layer
        inpL = cur;
    }

    cur = inpL;

    // final norm
    cur = build_norm(cur, model.output_norm, NULL, LLM_NORM_RMS, -1);
    cb(cur, "result_norm", -1);

    res->t_embd = cur;

    // lm_head
    cur = build_lora_mm(model.output, cur);
    cb(cur, "result_output", -1);

    // Explicitly mark as output tensor to ensure proper backend assignment
    ggml_set_output(cur);

    res->t_logits = cur;

    ggml_build_forward_expand(gf, cur);
}
97 | | |
98 | | ggml_tensor * llm_build_plamo2::build_plamo2_attn_layer(llm_graph_input_attn_kv * inp, |
99 | | ggml_tensor * inp_pos, |
100 | | ggml_tensor * cur, |
101 | | const llama_model & model, |
102 | 0 | int il) { |
103 | | // self-attention |
104 | 0 | { |
105 | | // PLaMo-2 uses combined QKV tensor |
106 | 0 | ggml_tensor * qkv = build_lora_mm(model.layers[il].wqkv, cur); |
107 | 0 | cb(qkv, "wqkv", il); |
108 | | |
109 | | // split QKV tensor into Q, K, V |
110 | 0 | const int64_t n_embd_head_q = hparams.n_embd_head_k(); |
111 | 0 | const int64_t n_embd_head_k = hparams.n_embd_head_k(); |
112 | 0 | const int64_t n_embd_head_v = hparams.n_embd_head_v(); |
113 | 0 | int32_t n_head = hparams.n_head(il); |
114 | 0 | int32_t n_head_kv = hparams.n_head_kv(il); |
115 | |
|
116 | 0 | const int64_t q_offset = 0; |
117 | 0 | const int64_t k_offset = n_embd_head_q * n_head; |
118 | 0 | const int64_t v_offset = k_offset + n_embd_head_k * n_head_kv; |
119 | |
|
120 | 0 | ggml_tensor * Qcur = ggml_view_3d(ctx0, qkv, n_embd_head_q, n_head, n_tokens, n_embd_head_q * sizeof(float), |
121 | 0 | qkv->nb[1], q_offset * ggml_element_size(qkv)); |
122 | 0 | ggml_tensor * Kcur = ggml_view_3d(ctx0, qkv, n_embd_head_k, n_head_kv, n_tokens, n_embd_head_k * sizeof(float), |
123 | 0 | qkv->nb[1], k_offset * ggml_element_size(qkv)); |
124 | 0 | ggml_tensor * Vcur = ggml_view_3d(ctx0, qkv, n_embd_head_v, n_head_kv, n_tokens, n_embd_head_v * sizeof(float), |
125 | 0 | qkv->nb[1], v_offset * ggml_element_size(qkv)); |
126 | |
|
127 | 0 | cb(Qcur, "Qcur", il); |
128 | 0 | cb(Kcur, "Kcur", il); |
129 | 0 | cb(Vcur, "Vcur", il); |
130 | |
|
131 | 0 | Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il); |
132 | 0 | cb(Qcur, "Qcur_normed", il); |
133 | |
|
134 | 0 | Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, |
135 | 0 | ext_factor, attn_factor, beta_fast, beta_slow); |
136 | |
|
137 | 0 | Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il); |
138 | 0 | cb(Kcur, "Kcur_normed", il); |
139 | |
|
140 | 0 | Kcur = ggml_rope_ext(ctx0, Kcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, |
141 | 0 | ext_factor, attn_factor, beta_fast, beta_slow); |
142 | |
|
143 | 0 | cur = build_attn(inp, |
144 | 0 | model.layers[il].wo, NULL, |
145 | 0 | Qcur, Kcur, Vcur, NULL, NULL, NULL, 1.0f / sqrtf(float(n_embd_head_v)), il); |
146 | 0 | } |
147 | |
|
148 | 0 | cb(cur, "attn_out", il); |
149 | |
|
150 | 0 | return cur; |
151 | 0 | } |
152 | | |
153 | | ggml_tensor * llm_build_plamo2::build_plamo2_mamba_layer(llm_graph_input_rs * inp, |
154 | | ggml_tensor * cur, |
155 | | const llama_model & model, |
156 | | const llama_ubatch & ubatch, |
157 | 0 | int il) { |
158 | 0 | const auto * mctx_cur = inp->mctx; |
159 | |
|
160 | 0 | const auto kv_head = mctx_cur->get_head(); |
161 | |
|
162 | 0 | const int64_t d_conv = hparams.ssm_d_conv; |
163 | 0 | const int64_t d_inner = hparams.ssm_d_inner; |
164 | 0 | const int64_t d_state = hparams.ssm_d_state; |
165 | 0 | const int64_t n_heads = hparams.ssm_dt_rank; |
166 | 0 | const int64_t head_dim = d_inner / n_heads; |
167 | 0 | const int64_t n_group = hparams.ssm_n_group; |
168 | 0 | const int64_t n_seqs = ubatch.n_seqs; |
169 | |
|
170 | 0 | const int64_t n_seq_tokens = ubatch.n_seq_tokens; |
171 | |
|
172 | 0 | GGML_ASSERT(n_seqs != 0); |
173 | 0 | GGML_ASSERT(ubatch.equal_seqs()); |
174 | 0 | GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs); |
175 | 0 | GGML_ASSERT(d_inner % n_head == 0); |
176 | 0 | GGML_ASSERT(n_group == 0); |
177 | |
|
178 | 0 | ggml_tensor * conv_states_all = mctx_cur->get_r_l(il); |
179 | 0 | ggml_tensor * ssm_states_all = mctx_cur->get_s_l(il); |
180 | |
|
181 | 0 | ggml_tensor * conv = build_rs(inp, conv_states_all, hparams.n_embd_r(), n_seqs); |
182 | 0 | conv = ggml_reshape_3d(ctx0, conv, d_conv - 1, d_inner + 2 * n_group * d_state, n_seqs); |
183 | | |
184 | | // {n_embd, n_tokens} => {n_embd, n_seq_tokens, n_seqs} |
185 | 0 | cur = ggml_reshape_3d(ctx0, cur, cur->ne[0], n_seq_tokens, n_seqs); |
186 | | |
187 | | // in_proj: {n_embd, 2*d_inner} @ {n_embd, n_seq_tokens, n_seqs} => {2*d_inner, n_seq_tokens, n_seqs} |
188 | 0 | ggml_tensor * zx = build_lora_mm(model.layers[il].ssm_in, cur); |
189 | 0 | cb(zx, "mamba_in_proj", il); |
190 | | // {8192, 5, 1, 1} -> {8192, 1, 5, 1} |
191 | 0 | zx = ggml_permute(ctx0, zx, 0, 2, 1, 3); |
192 | 0 | zx = ggml_cont_4d(ctx0, zx, head_dim * 2, n_heads, n_seq_tokens, n_seqs); |
193 | 0 | cb(zx, "mamba_in_proj_out", il); |
194 | | |
195 | | // split into z and x |
196 | | // => {head_dim * n_heads, n_seq_tokens, n_seqs} |
197 | 0 | ggml_tensor * x = ggml_view_4d(ctx0, zx, head_dim, n_heads, n_seq_tokens, n_seqs, zx->nb[1], zx->nb[2], zx->nb[3], |
198 | 0 | head_dim * ggml_element_size(zx)); |
199 | 0 | x = ggml_cont_3d(ctx0, x, head_dim * n_heads, n_seq_tokens, n_seqs); |
200 | | // x = ggml_permute(ctx0, x, 0, 2, 1, 3); |
201 | 0 | cb(x, "mamba_x_split", il); |
202 | |
|
203 | 0 | ggml_tensor * z = |
204 | 0 | ggml_view_4d(ctx0, zx, head_dim, n_heads, n_seq_tokens, n_seqs, zx->nb[1], zx->nb[2], zx->nb[3], 0); |
205 | 0 | cb(z, "mamba_z_split", il); |
206 | | |
207 | | // conv1d |
208 | 0 | { |
209 | | // => {d_conv - 1 + n_seq_tokens, d_inner, n_seqs} |
210 | 0 | ggml_tensor * conv_x = ggml_concat(ctx0, conv, ggml_transpose(ctx0, x), 0); |
211 | 0 | cb(conv_x, "mamba_conv1d_input", il); |
212 | | |
213 | | // copy last (d_conv - 1) columns back into the state cache |
214 | 0 | ggml_tensor * last_conv = ggml_view_3d(ctx0, conv_x, d_conv - 1, d_inner, n_seqs, conv_x->nb[1], conv_x->nb[2], |
215 | 0 | n_seq_tokens * (conv_x->nb[0])); |
216 | |
|
217 | 0 | ggml_build_forward_expand(gf, ggml_cpy(ctx0, last_conv, |
218 | 0 | ggml_view_1d(ctx0, conv_states_all, |
219 | 0 | (d_conv - 1) * (d_inner + 2 * n_group * d_state) * (n_seqs), |
220 | 0 | kv_head * (d_conv - 1) * (d_inner + 2 * n_group * d_state) * |
221 | 0 | ggml_element_size(conv_states_all)))); |
222 | 0 | cb(conv_states_all, "mamba_conv1d_state", il); |
223 | | |
224 | | // 1D convolution |
225 | 0 | x = ggml_ssm_conv(ctx0, conv_x, model.layers[il].ssm_conv1d); |
226 | 0 | cb(x, "mamba_conv1d", il); |
227 | |
|
228 | 0 | x = ggml_silu(ctx0, x); |
229 | 0 | cb(x, "mamba_conv1d_silu", il); |
230 | 0 | } |
231 | | |
232 | | // SSM |
233 | 0 | { |
234 | | // bcdt_proj: {d_inner, dt_rank + 2*d_state} @ {d_inner, n_seq_tokens, n_seqs} => {dt_rank + 2*d_state, n_seq_tokens, n_seqs} |
235 | 0 | ggml_tensor * x_bcdt = build_lora_mm(model.layers[il].ssm_x, x); |
236 | 0 | cb(x_bcdt, "mamba_bcdt_proj", il); |
237 | | |
238 | | // split into dt, B, C |
239 | 0 | const int64_t dt_dim = std::max(64, int(hparams.n_embd / 16)); |
240 | 0 | ggml_tensor * B = ggml_view_3d(ctx0, x_bcdt, d_state, n_seq_tokens, n_seqs, x_bcdt->nb[1], x_bcdt->nb[2], 0); |
241 | 0 | ggml_tensor * C = ggml_view_3d(ctx0, x_bcdt, d_state, n_seq_tokens, n_seqs, x_bcdt->nb[1], x_bcdt->nb[2], |
242 | 0 | ggml_element_size(x_bcdt) * d_state); |
243 | 0 | ggml_tensor * dt = ggml_view_3d(ctx0, x_bcdt, dt_dim, n_seq_tokens, n_seqs, x_bcdt->nb[1], x_bcdt->nb[2], |
244 | 0 | ggml_element_size(x_bcdt) * (2 * d_state)); |
245 | 0 | cb(B, "mamba_B_raw", il); |
246 | 0 | cb(C, "mamba_C_raw", il); |
247 | 0 | cb(dt, "mamba_dt_raw", il); |
248 | | |
249 | | // Apply RMS norm to dt, B, C (PLaMo-2 specific) |
250 | 0 | B = build_norm(B, model.layers[il].ssm_b_norm, NULL, LLM_NORM_RMS, il); |
251 | 0 | C = build_norm(C, model.layers[il].ssm_c_norm, NULL, LLM_NORM_RMS, il); |
252 | 0 | dt = build_norm(dt, model.layers[il].ssm_dt_norm, NULL, LLM_NORM_RMS, il); |
253 | 0 | cb(B, "mamba_B_normed", il); |
254 | 0 | cb(C, "mamba_C_normed", il); |
255 | 0 | cb(dt, "mamba_dt_normed", il); |
256 | | |
257 | | // dt_proj: {dt_rank, d_inner} @ {dt_rank, n_seq_tokens, n_seqs} => {d_inner, n_seq_tokens, n_seqs} |
258 | 0 | dt = build_lora_mm(model.layers[il].ssm_dt, dt); |
259 | 0 | dt = ggml_add(ctx0, dt, model.layers[il].ssm_dt_b); |
260 | 0 | cb(dt, "mamba_dt_proj", il); |
261 | |
|
262 | 0 | ggml_tensor * A = ggml_reshape_2d(ctx0, model.layers[il].ssm_a, 1, n_heads); |
263 | 0 | cb(A, "mamba_A", il); |
264 | |
|
265 | 0 | x = ggml_view_4d(ctx0, x, head_dim, n_heads, n_seq_tokens, n_seqs, head_dim * ggml_element_size(x), |
266 | 0 | head_dim * n_heads * ggml_element_size(x), |
267 | 0 | head_dim * n_heads * n_seq_tokens * ggml_element_size(x), 0); |
268 | 0 | B = ggml_view_4d(ctx0, B, d_state, 1, n_seq_tokens, n_seqs, d_state * B->nb[0], B->nb[1], B->nb[2], 0); |
269 | 0 | C = ggml_view_4d(ctx0, C, d_state, 1, n_seq_tokens, n_seqs, d_state * C->nb[0], C->nb[1], C->nb[2], 0); |
270 | | |
271 | | // use the states and the indices provided by build_recurrent_state |
272 | | // (this is necessary in order to properly use the states before they are overwritten, |
273 | | // while avoiding to make unnecessary copies of the states) |
274 | 0 | auto get_ssm_rows = [&](ggml_context * ctx, ggml_tensor * states, ggml_tensor * ids) { |
275 | 0 | ggml_tensor * ssm = ggml_reshape_4d(ctx, states, d_state, head_dim, n_heads, mctx_cur->get_size()); |
276 | | |
277 | | // Custom operator to optimize the parallel associative scan |
278 | | // as described in the Annex D of the Mamba paper. |
279 | | // => {d_inner, n_seq_tokens, n_seqs} and {d_state, d_inner, n_seqs} |
280 | 0 | return ggml_ssm_scan(ctx, ssm, x, dt, A, B, C, ids); |
281 | 0 | }; |
282 | |
|
283 | 0 | ggml_tensor * y_ssm = build_rs(inp, ssm_states_all, hparams.n_embd_s(), ubatch.n_seqs, get_ssm_rows); |
284 | 0 | cb(y_ssm, "mamba_ssm_scan", il); |
285 | | |
286 | | // store last states |
287 | 0 | ggml_build_forward_expand( |
288 | 0 | gf, ggml_cpy( |
289 | 0 | ctx0, |
290 | 0 | ggml_view_1d(ctx0, y_ssm, n_heads * head_dim * d_state * n_seqs, |
291 | 0 | n_heads * head_dim * n_seq_tokens * n_seqs * ggml_element_size(y_ssm)), |
292 | 0 | ggml_view_1d(ctx0, ssm_states_all, n_heads * head_dim * d_state * n_seqs, |
293 | 0 | kv_head * n_seqs * n_heads * head_dim * d_state * ggml_element_size(ssm_states_all)))); |
294 | 0 | cb(ssm_states_all, "mamba_ssm_states", il); |
295 | |
|
296 | 0 | ggml_tensor * y = ggml_view_4d(ctx0, y_ssm, head_dim, n_heads, n_seq_tokens, n_seqs, |
297 | 0 | head_dim * ggml_element_size(x), head_dim * n_heads * ggml_element_size(x), |
298 | 0 | head_dim * n_heads * n_seq_tokens * ggml_element_size(x), 0); |
299 | 0 | cb(y, "mamba_y_view", il); |
300 | | |
301 | | // Add D parameter and apply gating with z |
302 | | // {d_inner, n_seq_tokens, n_seqs} * {d_inner} => {d_inner, n_seq_tokens, n_seqs} |
303 | 0 | ggml_tensor * D = ggml_reshape_2d(ctx0, model.layers[il].ssm_d, 1, n_heads); |
304 | 0 | y = ggml_add(ctx0, y, ggml_mul(ctx0, x, D)); |
305 | 0 | cb(y, "mamba_y_add_d", il); |
306 | |
|
307 | 0 | y = ggml_swiglu_split(ctx0, ggml_cont(ctx0, z), y); |
308 | 0 | cb(y, "mamba_y_swiglu_z", il); |
309 | | |
310 | | // out_proj: {d_inner, n_embd} @ {d_inner, n_seq_tokens, n_seqs} => {n_embd, n_seq_tokens, n_seqs} |
311 | 0 | y = ggml_view_3d(ctx0, y, head_dim * n_heads, n_seq_tokens, n_seqs, y->nb[2], y->nb[3], 0); |
312 | 0 | cur = build_lora_mm(model.layers[il].ssm_out, y); |
313 | 0 | cb(cur, "mamba_out_proj", il); |
314 | 0 | } |
315 | | |
316 | | // {n_embd, n_seq_tokens, n_seqs} => {n_embd, n_tokens} |
317 | 0 | cur = ggml_reshape_2d(ctx0, cur, cur->ne[0], n_seq_tokens * n_seqs); |
318 | 0 | cb(cur, "mamba_out", il); |
319 | |
|
320 | 0 | return cur; |
321 | 0 | } |