/src/llama.cpp/src/models/plamo2.cpp
Line | Count | Source |
1 | | #include "models.h" |
2 | | |
3 | | #include "llama-memory-recurrent.h" |
4 | | |
// Build the full PLaMo-2 compute graph: a hybrid stack where each layer is
// either a Mamba (recurrent/SSM) layer or a standard attention layer, selected
// per-layer via hparams.is_recurrent(il). Layer layout is:
//   pre-norm -> mixer (mamba|attn) -> post-norm -> residual ->
//   pre-ffn-norm -> SwiGLU FFN -> post-ffn-norm -> residual
// Results are published through res->t_embd (pre-lm_head embeddings) and
// res->t_logits (lm_head output).
llm_build_plamo2::llm_build_plamo2(const llama_model & model, const llm_graph_params & params) :
    llm_build_mamba_base(params) {
    ggml_tensor * cur;
    ggml_tensor * inpL;

    // token embeddings: {n_embd, n_tokens}
    inpL = build_inp_embd(model.tok_embd);
    cb(inpL, "embedding_output", -1);

    // positions are only consumed by the attention layers (for RoPE)
    ggml_tensor * inp_pos = build_inp_pos();

    // hybrid memory input: carries both the recurrent state (for Mamba layers)
    // and the KV cache (for attention layers)
    auto * inp_hybrid = build_inp_mem_hybrid();

    ggml_tensor * inp_out_ids = build_inp_out_ids();

    for (int il = 0; il < n_layer; ++il) {
        ggml_tensor * residual = inpL;

        // ggml_graph_add_node(gf, model.layers[il].attn_norm);
        // cb(model.layers[il].attn_norm, "attn_norm", il);

        // pre_mixer_norm
        cur = build_norm(inpL, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il);

        // check if this layer is Mamba or Attention
        bool is_mamba_layer = hparams.is_recurrent(il);

        if (is_mamba_layer) {
            // PLaMo-2 Mamba layer
            cur = build_plamo2_mamba_layer(inp_hybrid->get_recr(), cur, model, ubatch, il);
        } else {
            // PLaMo-2 Attention layer
            cur = build_plamo2_attn_layer(inp_hybrid->get_attn(), inp_pos, cur, model, il);
        }

        // post_mixer_norm
        cur = build_norm(cur, model.layers[il].attn_post_norm, NULL, LLM_NORM_RMS, il);
        cb(cur, "attn_post_norm", il);

        // residual connection (note: applied after the post-mixer norm)
        cur = ggml_add(ctx0, cur, residual);
        cb(cur, "attn_residual", il);
        residual = cur;

        // pre-ffn norm
        cur = build_norm(cur, model.layers[il].ffn_norm, NULL, LLM_NORM_RMS, il);
        cb(cur, "ffn_pre_norm", il);

        // feed-forward network (SwiGLU; gate is fused into ffn_up, hence the NULL gate slot)
        cur = build_ffn(cur,
                model.layers[il].ffn_up, NULL, NULL,
                NULL, NULL, NULL,
                model.layers[il].ffn_down, NULL, NULL,
                NULL, LLM_FFN_SWIGLU, LLM_FFN_SEQ, il);
        cb(cur, "ffn_out", il);

        // post ffn norm
        cur = build_norm(cur, model.layers[il].ffn_post_norm, NULL, LLM_NORM_RMS, il);
        cb(cur, "ffn_post_norm", il);

        // on the last layer, keep only the rows whose logits are actually needed;
        // the residual must be pruned identically so the final add stays aligned
        if (il == n_layer - 1 && inp_out_ids) {
            cur = ggml_get_rows(ctx0, cur, inp_out_ids);
            residual = ggml_get_rows(ctx0, residual, inp_out_ids);
        }

        // residual connection
        cur = ggml_add(ctx0, cur, residual);
        cb(cur, "ffn_residual", il);

        inpL = cur;
    }

    cur = inpL;

    // final norm
    cur = build_norm(cur, model.output_norm, NULL, LLM_NORM_RMS, -1);
    cb(cur, "result_norm", -1);

    res->t_embd = cur;

    // lm_head
    cur = build_lora_mm(model.output, cur);
    cb(cur, "result_output", -1);

    // Explicitly mark as output tensor to ensure proper backend assignment
    ggml_set_output(cur);

    res->t_logits = cur;

    ggml_build_forward_expand(gf, cur);
}
96 | | |
// Build one PLaMo-2 attention layer.
//
// PLaMo-2 stores Q, K and V in a single fused projection (wqkv); the three
// parts are carved out of the projection output with strided 3D views (no
// copy), then Q and K each get a per-head RMS norm followed by RoPE before
// the attention itself.
//
// @param inp      attention KV-cache input for this graph
// @param inp_pos  token positions, consumed by ggml_rope_ext
// @param cur      layer input, already pre-norm'ed by the caller: {n_embd, n_tokens}
// @param model    weights (wqkv, wo, attn_q_norm, attn_k_norm)
// @param il       layer index (also used for callbacks)
// @return         attention output after the wo projection
ggml_tensor * llm_build_plamo2::build_plamo2_attn_layer(llm_graph_input_attn_kv * inp,
                                                        ggml_tensor *            inp_pos,
                                                        ggml_tensor *            cur,
                                                        const llama_model &      model,
                                                        int                      il) {
    // self-attention
    {
        // PLaMo-2 uses combined QKV tensor
        ggml_tensor * qkv = build_lora_mm(model.layers[il].wqkv, cur);
        cb(qkv, "wqkv", il);

        // split QKV tensor into Q, K, V
        // note: Q head size deliberately aliases n_embd_head_k (Q and K share the head dim)
        const int64_t n_embd_head_q = hparams.n_embd_head_k;
        const int64_t n_embd_head_k = hparams.n_embd_head_k;
        const int64_t n_embd_head_v = hparams.n_embd_head_v;
        int32_t       n_head        = hparams.n_head(il);
        int32_t       n_head_kv     = hparams.n_head_kv(il);

        // element offsets of each part inside a fused qkv row: [Q | K | V]
        const int64_t q_offset = 0;
        const int64_t k_offset = n_embd_head_q * n_head;
        const int64_t v_offset = k_offset + n_embd_head_k * n_head_kv;

        // strided views into qkv (byte offsets -> multiply by element size);
        // row stride nb[1] is inherited from qkv so all three share its layout
        ggml_tensor * Qcur = ggml_view_3d(ctx0, qkv, n_embd_head_q, n_head, n_tokens, n_embd_head_q * sizeof(float),
                                          qkv->nb[1], q_offset * ggml_element_size(qkv));
        ggml_tensor * Kcur = ggml_view_3d(ctx0, qkv, n_embd_head_k, n_head_kv, n_tokens, n_embd_head_k * sizeof(float),
                                          qkv->nb[1], k_offset * ggml_element_size(qkv));
        ggml_tensor * Vcur = ggml_view_3d(ctx0, qkv, n_embd_head_v, n_head_kv, n_tokens, n_embd_head_v * sizeof(float),
                                          qkv->nb[1], v_offset * ggml_element_size(qkv));

        cb(Qcur, "Qcur", il);
        cb(Kcur, "Kcur", il);
        cb(Vcur, "Vcur", il);

        // QK-norm: RMS-normalize Q before RoPE (PLaMo-2 specific)
        Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il);
        cb(Qcur, "Qcur_normed", il);

        Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                             ext_factor, attn_factor, beta_fast, beta_slow);

        // same treatment for K: RMS norm, then RoPE
        Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il);
        cb(Kcur, "Kcur_normed", il);

        Kcur = ggml_rope_ext(ctx0, Kcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                             ext_factor, attn_factor, beta_fast, beta_slow);

        // scaled dot-product attention; scale uses the V head dim (matches K/Q head dim here)
        cur = build_attn(inp,
                model.layers[il].wo, NULL,
                Qcur, Kcur, Vcur, NULL, NULL, NULL, 1.0f / sqrtf(float(n_embd_head_v)), il);
    }

    cb(cur, "attn_out", il);

    return cur;
}
151 | | |
// Build one PLaMo-2 Mamba (selective SSM) layer.
//
// Pipeline: in_proj -> split into (z, x) -> causal conv1d (+SiLU) on x ->
// project x to (B, C, dt) -> per-part RMS norm -> dt_proj -> ggml_ssm_scan ->
// add D skip -> SwiGLU gate with z -> out_proj.
//
// Both recurrent states (the rolling conv window and the SSM state) are read
// from and written back to the recurrent memory cache for this layer, so the
// order of the ggml_cpy/ggml_build_forward_expand calls below is significant.
//
// @param inp    recurrent-state graph input (provides the state cache context)
// @param cur    layer input, already pre-norm'ed: {n_embd, n_tokens}
// @param model  weights (ssm_in, ssm_conv1d, ssm_x, ssm_dt, ssm_a, ssm_d, ssm_out, norms)
// @param ubatch current micro-batch (sequence layout)
// @param il     layer index
// @return       mixer output: {n_embd, n_tokens}
ggml_tensor * llm_build_plamo2::build_plamo2_mamba_layer(llm_graph_input_rs * inp,
                                                         ggml_tensor *        cur,
                                                         const llama_model &  model,
                                                         const llama_ubatch & ubatch,
                                                         int                  il) {
    const auto * mctx_cur = inp->mctx;

    // slot in the recurrent state cache where this batch's states live
    const auto kv_head = mctx_cur->get_head();

    const int64_t d_conv   = hparams.ssm_d_conv;     // conv kernel width
    const int64_t d_inner  = hparams.ssm_d_inner;    // inner (expanded) dim
    const int64_t d_state  = hparams.ssm_d_state;    // SSM state size per head dim
    const int64_t n_heads  = hparams.ssm_dt_rank;    // dt_rank doubles as the head count here
    const int64_t head_dim = d_inner / n_heads;
    const int64_t n_group  = hparams.ssm_n_group;
    const int64_t n_seqs   = ubatch.n_seqs;

    const int64_t n_seq_tokens = ubatch.n_seq_tokens;

    // recurrent layers require equal-length sequences in the ubatch
    GGML_ASSERT(n_seqs != 0);
    GGML_ASSERT(ubatch.equal_seqs());
    GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs);

    ggml_tensor * conv_states_all = mctx_cur->get_r_l(il);
    ggml_tensor * ssm_states_all  = mctx_cur->get_s_l(il);

    // load the rolling conv window: {d_conv - 1, d_inner + 2*n_group*d_state, n_seqs}
    ggml_tensor * conv = build_rs(inp, conv_states_all, hparams.n_embd_r(), n_seqs);
    conv               = ggml_reshape_3d(ctx0, conv, d_conv - 1, d_inner + 2 * n_group * d_state, n_seqs);

    // {n_embd, n_tokens} => {n_embd, n_seq_tokens, n_seqs}
    cur = ggml_reshape_3d(ctx0, cur, cur->ne[0], n_seq_tokens, n_seqs);

    // in_proj: {n_embd, 2*d_inner} @ {n_embd, n_seq_tokens, n_seqs} => {2*d_inner, n_seq_tokens, n_seqs}
    ggml_tensor * zx = build_lora_mm(model.layers[il].ssm_in, cur);
    cb(zx, "mamba_in_proj", il);
    // {8192, 5, 1, 1} -> {8192, 1, 5, 1}
    zx = ggml_permute(ctx0, zx, 0, 2, 1, 3);
    // regroup so each head's (z, x) pair is contiguous: {2*head_dim, n_heads, n_seq_tokens, n_seqs}
    zx = ggml_cont_4d(ctx0, zx, head_dim * 2, n_heads, n_seq_tokens, n_seqs);
    cb(zx, "mamba_in_proj_out", il);

    // split into z and x
    // => {head_dim * n_heads, n_seq_tokens, n_seqs}
    // x is the second half of each per-head pair (offset head_dim elements)
    ggml_tensor * x = ggml_view_4d(ctx0, zx, head_dim, n_heads, n_seq_tokens, n_seqs, zx->nb[1], zx->nb[2], zx->nb[3],
                                   head_dim * ggml_element_size(zx));
    x               = ggml_cont_3d(ctx0, x, head_dim * n_heads, n_seq_tokens, n_seqs);
    // x = ggml_permute(ctx0, x, 0, 2, 1, 3);
    cb(x, "mamba_x_split", il);

    // z is the first half (offset 0); kept as a view, made contiguous at the gating step
    ggml_tensor * z =
        ggml_view_4d(ctx0, zx, head_dim, n_heads, n_seq_tokens, n_seqs, zx->nb[1], zx->nb[2], zx->nb[3], 0);
    cb(z, "mamba_z_split", il);

    // conv1d
    {
        // prepend the cached window to the new tokens
        // => {d_conv - 1 + n_seq_tokens, d_inner, n_seqs}
        ggml_tensor * conv_x = ggml_concat(ctx0, conv, ggml_transpose(ctx0, x), 0);
        cb(conv_x, "mamba_conv1d_input", il);

        // copy last (d_conv - 1) columns back into the state cache
        ggml_tensor * last_conv = ggml_view_3d(ctx0, conv_x, d_conv - 1, d_inner, n_seqs, conv_x->nb[1], conv_x->nb[2],
                                               n_seq_tokens * (conv_x->nb[0]));

        // force the state write-back into the graph even though its result is unused downstream
        ggml_build_forward_expand(gf, ggml_cpy(ctx0, last_conv,
                                               ggml_view_1d(ctx0, conv_states_all,
                                                            (d_conv - 1) * (d_inner + 2 * n_group * d_state) * (n_seqs),
                                                            kv_head * (d_conv - 1) * (d_inner + 2 * n_group * d_state) *
                                                                ggml_element_size(conv_states_all))));
        cb(conv_states_all, "mamba_conv1d_state", il);

        // 1D convolution
        x = ggml_ssm_conv(ctx0, conv_x, model.layers[il].ssm_conv1d);
        cb(x, "mamba_conv1d", il);

        x = ggml_silu(ctx0, x);
        cb(x, "mamba_conv1d_silu", il);
    }

    // SSM
    {
        // bcdt_proj: {d_inner, dt_rank + 2*d_state} @ {d_inner, n_seq_tokens, n_seqs} => {dt_rank + 2*d_state, n_seq_tokens, n_seqs}
        ggml_tensor * x_bcdt = build_lora_mm(model.layers[il].ssm_x, x);
        cb(x_bcdt, "mamba_bcdt_proj", il);

        // split into dt, B, C (laid out as [B | C | dt] along dim 0)
        // NOTE(review): dt_dim is recomputed as max(64, n_embd/16) rather than read
        // from hparams — presumably matches the PLaMo-2 checkpoint convention; verify
        const int64_t dt_dim = std::max(64, int(hparams.n_embd / 16));
        ggml_tensor * B  = ggml_view_3d(ctx0, x_bcdt, d_state, n_seq_tokens, n_seqs, x_bcdt->nb[1], x_bcdt->nb[2], 0);
        ggml_tensor * C  = ggml_view_3d(ctx0, x_bcdt, d_state, n_seq_tokens, n_seqs, x_bcdt->nb[1], x_bcdt->nb[2],
                                        ggml_element_size(x_bcdt) * d_state);
        ggml_tensor * dt = ggml_view_3d(ctx0, x_bcdt, dt_dim, n_seq_tokens, n_seqs, x_bcdt->nb[1], x_bcdt->nb[2],
                                        ggml_element_size(x_bcdt) * (2 * d_state));
        cb(B, "mamba_B_raw", il);
        cb(C, "mamba_C_raw", il);
        cb(dt, "mamba_dt_raw", il);

        // Apply RMS norm to dt, B, C (PLaMo-2 specific)
        B  = build_norm(B, model.layers[il].ssm_b_norm, NULL, LLM_NORM_RMS, il);
        C  = build_norm(C, model.layers[il].ssm_c_norm, NULL, LLM_NORM_RMS, il);
        dt = build_norm(dt, model.layers[il].ssm_dt_norm, NULL, LLM_NORM_RMS, il);
        cb(B, "mamba_B_normed", il);
        cb(C, "mamba_C_normed", il);
        cb(dt, "mamba_dt_normed", il);

        // dt_proj: {dt_rank, d_inner} @ {dt_rank, n_seq_tokens, n_seqs} => {d_inner, n_seq_tokens, n_seqs}
        dt = build_lora_mm(model.layers[il].ssm_dt, dt);
        dt = ggml_add(ctx0, dt, model.layers[il].ssm_dt_b);
        cb(dt, "mamba_dt_proj", il);

        // A is one scalar per head: {1, n_heads}
        ggml_tensor * A = ggml_reshape_2d(ctx0, model.layers[il].ssm_a, 1, n_heads);
        cb(A, "mamba_A", il);

        // reshape x/B/C into the 4D layout ggml_ssm_scan expects
        x = ggml_view_4d(ctx0, x, head_dim, n_heads, n_seq_tokens, n_seqs, head_dim * ggml_element_size(x),
                         head_dim * n_heads * ggml_element_size(x),
                         head_dim * n_heads * n_seq_tokens * ggml_element_size(x), 0);
        B = ggml_view_4d(ctx0, B, d_state, 1, n_seq_tokens, n_seqs, d_state * B->nb[0], B->nb[1], B->nb[2], 0);
        C = ggml_view_4d(ctx0, C, d_state, 1, n_seq_tokens, n_seqs, d_state * C->nb[0], C->nb[1], C->nb[2], 0);

        // use the states and the indices provided by build_recurrent_state
        // (this is necessary in order to properly use the states before they are overwritten,
        // while avoiding to make unnecessary copies of the states)
        auto get_ssm_rows = [&](ggml_context * ctx, ggml_tensor * states, ggml_tensor * ids) {
            ggml_tensor * ssm = ggml_reshape_4d(ctx, states, d_state, head_dim, n_heads, mctx_cur->get_size());

            // Custom operator to optimize the parallel associative scan
            // as described in the Annex D of the Mamba paper.
            // => {d_inner, n_seq_tokens, n_seqs} and {d_state, d_inner, n_seqs}
            return ggml_ssm_scan(ctx, ssm, x, dt, A, B, C, ids);
        };

        // y_ssm packs the per-token outputs followed by the final states (split below)
        ggml_tensor * y_ssm = build_rs(inp, ssm_states_all, hparams.n_embd_s(), ubatch.n_seqs, get_ssm_rows);
        cb(y_ssm, "mamba_ssm_scan", il);

        // store last states (they sit after the n_seq_tokens outputs inside y_ssm)
        ggml_build_forward_expand(
            gf, ggml_cpy(
                    ctx0,
                    ggml_view_1d(ctx0, y_ssm, n_heads * head_dim * d_state * n_seqs,
                                 n_heads * head_dim * n_seq_tokens * n_seqs * ggml_element_size(y_ssm)),
                    ggml_view_1d(ctx0, ssm_states_all, n_heads * head_dim * d_state * n_seqs,
                                 kv_head * n_seqs * n_heads * head_dim * d_state * ggml_element_size(ssm_states_all))));
        cb(ssm_states_all, "mamba_ssm_states", il);

        // view of the per-token outputs only: {head_dim, n_heads, n_seq_tokens, n_seqs}
        ggml_tensor * y = ggml_view_4d(ctx0, y_ssm, head_dim, n_heads, n_seq_tokens, n_seqs,
                                       head_dim * ggml_element_size(x), head_dim * n_heads * ggml_element_size(x),
                                       head_dim * n_heads * n_seq_tokens * ggml_element_size(x), 0);
        cb(y, "mamba_y_view", il);

        // Add D parameter and apply gating with z
        // {d_inner, n_seq_tokens, n_seqs} * {d_inner} => {d_inner, n_seq_tokens, n_seqs}
        ggml_tensor * D = ggml_reshape_2d(ctx0, model.layers[il].ssm_d, 1, n_heads);
        y               = ggml_add(ctx0, y, ggml_mul(ctx0, x, D));
        cb(y, "mamba_y_add_d", il);

        // SwiGLU-style gate: y * silu(z); z view must be made contiguous first
        y = ggml_swiglu_split(ctx0, ggml_cont(ctx0, z), y);
        cb(y, "mamba_y_swiglu_z", il);

        // out_proj: {d_inner, n_embd} @ {d_inner, n_seq_tokens, n_seqs} => {n_embd, n_seq_tokens, n_seqs}
        y   = ggml_view_3d(ctx0, y, head_dim * n_heads, n_seq_tokens, n_seqs, y->nb[2], y->nb[3], 0);
        cur = build_lora_mm(model.layers[il].ssm_out, y);
        cb(cur, "mamba_out_proj", il);
    }

    // {n_embd, n_seq_tokens, n_seqs} => {n_embd, n_tokens}
    cur = ggml_reshape_2d(ctx0, cur, cur->ne[0], n_seq_tokens * n_seqs);
    cb(cur, "mamba_out", il);

    return cur;
}