/src/llama.cpp/src/models/plamo2.cpp
Line | Count | Source |
1 | | #include "models.h" |
2 | | |
3 | | llm_build_plamo2::llm_build_plamo2(const llama_model & model, const llm_graph_params & params) : |
4 | 0 | llm_graph_context_mamba(params) { |
5 | 0 | ggml_tensor * cur; |
6 | 0 | ggml_tensor * inpL; |
7 | | |
8 | | // {n_embd, n_tokens} |
9 | 0 | inpL = build_inp_embd(model.tok_embd); |
10 | 0 | cb(inpL, "embedding_output", -1); |
11 | |
|
12 | 0 | ggml_tensor * inp_pos = build_inp_pos(); |
13 | |
|
14 | 0 | auto * inp_hybrid = build_inp_mem_hybrid(); |
15 | |
|
16 | 0 | ggml_tensor * inp_out_ids = build_inp_out_ids(); |
17 | |
|
18 | 0 | for (int il = 0; il < n_layer; ++il) { |
19 | 0 | ggml_tensor * residual = inpL; |
20 | | |
21 | | // ggml_graph_add_node(gf, model.layers[il].attn_norm); |
22 | | // cb(model.layers[il].attn_norm, "attn_norm", il); |
23 | | |
24 | | // pre_mixer_norm |
25 | 0 | cur = build_norm(inpL, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il); |
26 | | |
27 | | // check if this layer is Mamba or Attention |
28 | 0 | bool is_mamba_layer = hparams.is_recurrent(il); |
29 | |
|
30 | 0 | if (is_mamba_layer) { |
31 | | // PLaMo-2 Mamba layer |
32 | 0 | cur = build_plamo2_mamba_layer(inp_hybrid->get_recr(), cur, model, ubatch, il); |
33 | 0 | } else { |
34 | | // PLaMo-2 Attention layer |
35 | 0 | cur = build_plamo2_attn_layer(inp_hybrid->get_attn(), inp_pos, cur, model, il); |
36 | 0 | } |
37 | | |
38 | | // post_mixer_norm |
39 | 0 | cur = build_norm(cur, model.layers[il].attn_post_norm, NULL, LLM_NORM_RMS, il); |
40 | 0 | cb(cur, "attn_post_norm", il); |
41 | | |
42 | | // residual connection |
43 | 0 | cur = ggml_add(ctx0, cur, residual); |
44 | 0 | cb(cur, "attn_residual", il); |
45 | 0 | residual = cur; |
46 | | |
47 | | // pre-ffn norm |
48 | 0 | cur = build_norm(cur, model.layers[il].ffn_norm, NULL, LLM_NORM_RMS, il); |
49 | 0 | cb(cur, "ffn_pre_norm", il); |
50 | | |
51 | | // feed-forward network |
52 | 0 | cur = build_ffn(cur, |
53 | 0 | model.layers[il].ffn_up, NULL, NULL, |
54 | 0 | NULL, NULL, NULL, |
55 | 0 | model.layers[il].ffn_down, NULL, NULL, |
56 | 0 | NULL, LLM_FFN_SWIGLU, LLM_FFN_SEQ, il); |
57 | 0 | cb(cur, "ffn_out", il); |
58 | | |
59 | | // post ffn norm |
60 | 0 | cur = build_norm(cur, model.layers[il].ffn_post_norm, NULL, LLM_NORM_RMS, il); |
61 | 0 | cb(cur, "ffn_post_norm", il); |
62 | |
|
63 | 0 | if (il == n_layer - 1 && inp_out_ids) { |
64 | 0 | cur = ggml_get_rows(ctx0, cur, inp_out_ids); |
65 | 0 | residual = ggml_get_rows(ctx0, residual, inp_out_ids); |
66 | 0 | } |
67 | | |
68 | | // residual connection |
69 | 0 | cur = ggml_add(ctx0, cur, residual); |
70 | 0 | cb(cur, "ffn_residual", il); |
71 | |
|
72 | 0 | inpL = cur; |
73 | 0 | } |
74 | |
|
75 | 0 | cur = inpL; |
76 | | |
77 | | // final norm |
78 | 0 | cur = build_norm(cur, model.output_norm, NULL, LLM_NORM_RMS, -1); |
79 | 0 | cb(cur, "result_norm", -1); |
80 | |
|
81 | 0 | res->t_embd = cur; |
82 | | |
83 | | // lm_head |
84 | 0 | cur = build_lora_mm(model.output, cur); |
85 | 0 | cb(cur, "result_output", -1); |
86 | | |
87 | | // Explicitly mark as output tensor to ensure proper backend assignment |
88 | 0 | ggml_set_output(cur); |
89 | |
|
90 | 0 | res->t_logits = cur; |
91 | |
|
92 | 0 | ggml_build_forward_expand(gf, cur); |
93 | 0 | } |
94 | | |
95 | | ggml_tensor * llm_build_plamo2::build_plamo2_attn_layer(llm_graph_input_attn_kv * inp, |
96 | | ggml_tensor * inp_pos, |
97 | | ggml_tensor * cur, |
98 | | const llama_model & model, |
99 | 0 | int il) { |
100 | | // self-attention |
101 | 0 | { |
102 | | // PLaMo-2 uses combined QKV tensor |
103 | 0 | ggml_tensor * qkv = build_lora_mm(model.layers[il].wqkv, cur); |
104 | 0 | cb(qkv, "wqkv", il); |
105 | | |
106 | | // split QKV tensor into Q, K, V |
107 | 0 | const int64_t n_embd_head_q = hparams.n_embd_head_k; |
108 | 0 | const int64_t n_embd_head_k = hparams.n_embd_head_k; |
109 | 0 | const int64_t n_embd_head_v = hparams.n_embd_head_v; |
110 | 0 | int32_t n_head = hparams.n_head(il); |
111 | 0 | int32_t n_head_kv = hparams.n_head_kv(il); |
112 | |
|
113 | 0 | const int64_t q_offset = 0; |
114 | 0 | const int64_t k_offset = n_embd_head_q * n_head; |
115 | 0 | const int64_t v_offset = k_offset + n_embd_head_k * n_head_kv; |
116 | |
|
117 | 0 | ggml_tensor * Qcur = ggml_view_3d(ctx0, qkv, n_embd_head_q, n_head, n_tokens, n_embd_head_q * sizeof(float), |
118 | 0 | qkv->nb[1], q_offset * ggml_element_size(qkv)); |
119 | 0 | ggml_tensor * Kcur = ggml_view_3d(ctx0, qkv, n_embd_head_k, n_head_kv, n_tokens, n_embd_head_k * sizeof(float), |
120 | 0 | qkv->nb[1], k_offset * ggml_element_size(qkv)); |
121 | 0 | ggml_tensor * Vcur = ggml_view_3d(ctx0, qkv, n_embd_head_v, n_head_kv, n_tokens, n_embd_head_v * sizeof(float), |
122 | 0 | qkv->nb[1], v_offset * ggml_element_size(qkv)); |
123 | |
|
124 | 0 | cb(Qcur, "Qcur", il); |
125 | 0 | cb(Kcur, "Kcur", il); |
126 | 0 | cb(Vcur, "Vcur", il); |
127 | |
|
128 | 0 | Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il); |
129 | 0 | cb(Qcur, "Qcur_normed", il); |
130 | |
|
131 | 0 | Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, |
132 | 0 | ext_factor, attn_factor, beta_fast, beta_slow); |
133 | |
|
134 | 0 | Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il); |
135 | 0 | cb(Kcur, "Kcur_normed", il); |
136 | |
|
137 | 0 | Kcur = ggml_rope_ext(ctx0, Kcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, |
138 | 0 | ext_factor, attn_factor, beta_fast, beta_slow); |
139 | |
|
140 | 0 | cur = build_attn(inp, |
141 | 0 | model.layers[il].wo, NULL, |
142 | 0 | Qcur, Kcur, Vcur, NULL, NULL, NULL, 1.0f / sqrtf(float(n_embd_head_v)), il); |
143 | 0 | } |
144 | |
|
145 | 0 | cb(cur, "attn_out", il); |
146 | |
|
147 | 0 | return cur; |
148 | 0 | } |
149 | | |
150 | | ggml_tensor * llm_build_plamo2::build_plamo2_mamba_layer(llm_graph_input_rs * inp, |
151 | | ggml_tensor * cur, |
152 | | const llama_model & model, |
153 | | const llama_ubatch & ubatch, |
154 | 0 | int il) { |
155 | 0 | const auto * mctx_cur = inp->mctx; |
156 | |
|
157 | 0 | const auto kv_head = mctx_cur->get_head(); |
158 | |
|
159 | 0 | const int64_t d_conv = hparams.ssm_d_conv; |
160 | 0 | const int64_t d_inner = hparams.ssm_d_inner; |
161 | 0 | const int64_t d_state = hparams.ssm_d_state; |
162 | 0 | const int64_t n_heads = hparams.ssm_dt_rank; |
163 | 0 | const int64_t head_dim = d_inner / n_heads; |
164 | 0 | const int64_t n_group = hparams.ssm_n_group; |
165 | 0 | const int64_t n_seqs = ubatch.n_seqs; |
166 | |
|
167 | 0 | const int64_t n_seq_tokens = ubatch.n_seq_tokens; |
168 | |
|
169 | 0 | GGML_ASSERT(n_seqs != 0); |
170 | 0 | GGML_ASSERT(ubatch.equal_seqs()); |
171 | 0 | GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs); |
172 | |
|
173 | 0 | ggml_tensor * conv_states_all = mctx_cur->get_r_l(il); |
174 | 0 | ggml_tensor * ssm_states_all = mctx_cur->get_s_l(il); |
175 | |
|
176 | 0 | ggml_tensor * conv = build_rs(inp, conv_states_all, hparams.n_embd_r(), n_seqs); |
177 | 0 | conv = ggml_reshape_3d(ctx0, conv, d_conv - 1, d_inner + 2 * n_group * d_state, n_seqs); |
178 | | |
179 | | // {n_embd, n_tokens} => {n_embd, n_seq_tokens, n_seqs} |
180 | 0 | cur = ggml_reshape_3d(ctx0, cur, cur->ne[0], n_seq_tokens, n_seqs); |
181 | | |
182 | | // in_proj: {n_embd, 2*d_inner} @ {n_embd, n_seq_tokens, n_seqs} => {2*d_inner, n_seq_tokens, n_seqs} |
183 | 0 | ggml_tensor * zx = build_lora_mm(model.layers[il].ssm_in, cur); |
184 | 0 | cb(zx, "mamba_in_proj", il); |
185 | | // {8192, 5, 1, 1} -> {8192, 1, 5, 1} |
186 | 0 | zx = ggml_permute(ctx0, zx, 0, 2, 1, 3); |
187 | 0 | zx = ggml_cont_4d(ctx0, zx, head_dim * 2, n_heads, n_seq_tokens, n_seqs); |
188 | 0 | cb(zx, "mamba_in_proj_out", il); |
189 | | |
190 | | // split into z and x |
191 | | // => {head_dim * n_heads, n_seq_tokens, n_seqs} |
192 | 0 | ggml_tensor * x = ggml_view_4d(ctx0, zx, head_dim, n_heads, n_seq_tokens, n_seqs, zx->nb[1], zx->nb[2], zx->nb[3], |
193 | 0 | head_dim * ggml_element_size(zx)); |
194 | 0 | x = ggml_cont_3d(ctx0, x, head_dim * n_heads, n_seq_tokens, n_seqs); |
195 | | // x = ggml_permute(ctx0, x, 0, 2, 1, 3); |
196 | 0 | cb(x, "mamba_x_split", il); |
197 | |
|
198 | 0 | ggml_tensor * z = |
199 | 0 | ggml_view_4d(ctx0, zx, head_dim, n_heads, n_seq_tokens, n_seqs, zx->nb[1], zx->nb[2], zx->nb[3], 0); |
200 | 0 | cb(z, "mamba_z_split", il); |
201 | | |
202 | | // conv1d |
203 | 0 | { |
204 | | // => {d_conv - 1 + n_seq_tokens, d_inner, n_seqs} |
205 | 0 | ggml_tensor * conv_x = ggml_concat(ctx0, conv, ggml_transpose(ctx0, x), 0); |
206 | 0 | cb(conv_x, "mamba_conv1d_input", il); |
207 | | |
208 | | // copy last (d_conv - 1) columns back into the state cache |
209 | 0 | ggml_tensor * last_conv = ggml_view_3d(ctx0, conv_x, d_conv - 1, d_inner, n_seqs, conv_x->nb[1], conv_x->nb[2], |
210 | 0 | n_seq_tokens * (conv_x->nb[0])); |
211 | |
|
212 | 0 | ggml_build_forward_expand(gf, ggml_cpy(ctx0, last_conv, |
213 | 0 | ggml_view_1d(ctx0, conv_states_all, |
214 | 0 | (d_conv - 1) * (d_inner + 2 * n_group * d_state) * (n_seqs), |
215 | 0 | kv_head * (d_conv - 1) * (d_inner + 2 * n_group * d_state) * |
216 | 0 | ggml_element_size(conv_states_all)))); |
217 | 0 | cb(conv_states_all, "mamba_conv1d_state", il); |
218 | | |
219 | | // 1D convolution |
220 | 0 | x = ggml_ssm_conv(ctx0, conv_x, model.layers[il].ssm_conv1d); |
221 | 0 | cb(x, "mamba_conv1d", il); |
222 | |
|
223 | 0 | x = ggml_silu(ctx0, x); |
224 | 0 | cb(x, "mamba_conv1d_silu", il); |
225 | 0 | } |
226 | | |
227 | | // SSM |
228 | 0 | { |
229 | | // bcdt_proj: {d_inner, dt_rank + 2*d_state} @ {d_inner, n_seq_tokens, n_seqs} => {dt_rank + 2*d_state, n_seq_tokens, n_seqs} |
230 | 0 | ggml_tensor * x_bcdt = build_lora_mm(model.layers[il].ssm_x, x); |
231 | 0 | cb(x_bcdt, "mamba_bcdt_proj", il); |
232 | | |
233 | | // split into dt, B, C |
234 | 0 | const int64_t dt_dim = std::max(64, int(hparams.n_embd / 16)); |
235 | 0 | ggml_tensor * B = ggml_view_3d(ctx0, x_bcdt, d_state, n_seq_tokens, n_seqs, x_bcdt->nb[1], x_bcdt->nb[2], 0); |
236 | 0 | ggml_tensor * C = ggml_view_3d(ctx0, x_bcdt, d_state, n_seq_tokens, n_seqs, x_bcdt->nb[1], x_bcdt->nb[2], |
237 | 0 | ggml_element_size(x_bcdt) * d_state); |
238 | 0 | ggml_tensor * dt = ggml_view_3d(ctx0, x_bcdt, dt_dim, n_seq_tokens, n_seqs, x_bcdt->nb[1], x_bcdt->nb[2], |
239 | 0 | ggml_element_size(x_bcdt) * (2 * d_state)); |
240 | 0 | cb(B, "mamba_B_raw", il); |
241 | 0 | cb(C, "mamba_C_raw", il); |
242 | 0 | cb(dt, "mamba_dt_raw", il); |
243 | | |
244 | | // Apply RMS norm to dt, B, C (PLaMo-2 specific) |
245 | 0 | B = build_norm(B, model.layers[il].ssm_b_norm, NULL, LLM_NORM_RMS, il); |
246 | 0 | C = build_norm(C, model.layers[il].ssm_c_norm, NULL, LLM_NORM_RMS, il); |
247 | 0 | dt = build_norm(dt, model.layers[il].ssm_dt_norm, NULL, LLM_NORM_RMS, il); |
248 | 0 | cb(B, "mamba_B_normed", il); |
249 | 0 | cb(C, "mamba_C_normed", il); |
250 | 0 | cb(dt, "mamba_dt_normed", il); |
251 | | |
252 | | // dt_proj: {dt_rank, d_inner} @ {dt_rank, n_seq_tokens, n_seqs} => {d_inner, n_seq_tokens, n_seqs} |
253 | 0 | dt = build_lora_mm(model.layers[il].ssm_dt, dt); |
254 | 0 | dt = ggml_add(ctx0, dt, model.layers[il].ssm_dt_b); |
255 | 0 | cb(dt, "mamba_dt_proj", il); |
256 | |
|
257 | 0 | ggml_tensor * A = ggml_reshape_2d(ctx0, model.layers[il].ssm_a, 1, n_heads); |
258 | 0 | cb(A, "mamba_A", il); |
259 | |
|
260 | 0 | x = ggml_view_4d(ctx0, x, head_dim, n_heads, n_seq_tokens, n_seqs, head_dim * ggml_element_size(x), |
261 | 0 | head_dim * n_heads * ggml_element_size(x), |
262 | 0 | head_dim * n_heads * n_seq_tokens * ggml_element_size(x), 0); |
263 | 0 | B = ggml_view_4d(ctx0, B, d_state, 1, n_seq_tokens, n_seqs, d_state * B->nb[0], B->nb[1], B->nb[2], 0); |
264 | 0 | C = ggml_view_4d(ctx0, C, d_state, 1, n_seq_tokens, n_seqs, d_state * C->nb[0], C->nb[1], C->nb[2], 0); |
265 | | |
266 | | // use the states and the indices provided by build_recurrent_state |
267 | | // (this is necessary in order to properly use the states before they are overwritten, |
268 | | // while avoiding to make unnecessary copies of the states) |
269 | 0 | auto get_ssm_rows = [&](ggml_context * ctx, ggml_tensor * states, ggml_tensor * ids) { |
270 | 0 | ggml_tensor * ssm = ggml_reshape_4d(ctx, states, d_state, head_dim, n_heads, mctx_cur->get_size()); |
271 | | |
272 | | // Custom operator to optimize the parallel associative scan |
273 | | // as described in the Annex D of the Mamba paper. |
274 | | // => {d_inner, n_seq_tokens, n_seqs} and {d_state, d_inner, n_seqs} |
275 | 0 | return ggml_ssm_scan(ctx, ssm, x, dt, A, B, C, ids); |
276 | 0 | }; |
277 | |
|
278 | 0 | ggml_tensor * y_ssm = build_rs(inp, ssm_states_all, hparams.n_embd_s(), ubatch.n_seqs, get_ssm_rows); |
279 | 0 | cb(y_ssm, "mamba_ssm_scan", il); |
280 | | |
281 | | // store last states |
282 | 0 | ggml_build_forward_expand( |
283 | 0 | gf, ggml_cpy( |
284 | 0 | ctx0, |
285 | 0 | ggml_view_1d(ctx0, y_ssm, n_heads * head_dim * d_state * n_seqs, |
286 | 0 | n_heads * head_dim * n_seq_tokens * n_seqs * ggml_element_size(y_ssm)), |
287 | 0 | ggml_view_1d(ctx0, ssm_states_all, n_heads * head_dim * d_state * n_seqs, |
288 | 0 | kv_head * n_seqs * n_heads * head_dim * d_state * ggml_element_size(ssm_states_all)))); |
289 | 0 | cb(ssm_states_all, "mamba_ssm_states", il); |
290 | |
|
291 | 0 | ggml_tensor * y = ggml_view_4d(ctx0, y_ssm, head_dim, n_heads, n_seq_tokens, n_seqs, |
292 | 0 | head_dim * ggml_element_size(x), head_dim * n_heads * ggml_element_size(x), |
293 | 0 | head_dim * n_heads * n_seq_tokens * ggml_element_size(x), 0); |
294 | 0 | cb(y, "mamba_y_view", il); |
295 | | |
296 | | // Add D parameter and apply gating with z |
297 | | // {d_inner, n_seq_tokens, n_seqs} * {d_inner} => {d_inner, n_seq_tokens, n_seqs} |
298 | 0 | ggml_tensor * D = ggml_reshape_2d(ctx0, model.layers[il].ssm_d, 1, n_heads); |
299 | 0 | y = ggml_add(ctx0, y, ggml_mul(ctx0, x, D)); |
300 | 0 | cb(y, "mamba_y_add_d", il); |
301 | |
|
302 | 0 | y = ggml_swiglu_split(ctx0, ggml_cont(ctx0, z), y); |
303 | 0 | cb(y, "mamba_y_swiglu_z", il); |
304 | | |
305 | | // out_proj: {d_inner, n_embd} @ {d_inner, n_seq_tokens, n_seqs} => {n_embd, n_seq_tokens, n_seqs} |
306 | 0 | y = ggml_view_3d(ctx0, y, head_dim * n_heads, n_seq_tokens, n_seqs, y->nb[2], y->nb[3], 0); |
307 | 0 | cur = build_lora_mm(model.layers[il].ssm_out, y); |
308 | 0 | cb(cur, "mamba_out_proj", il); |
309 | 0 | } |
310 | | |
311 | | // {n_embd, n_seq_tokens, n_seqs} => {n_embd, n_tokens} |
312 | 0 | cur = ggml_reshape_2d(ctx0, cur, cur->ne[0], n_seq_tokens * n_seqs); |
313 | 0 | cb(cur, "mamba_out", il); |
314 | |
|
315 | 0 | return cur; |
316 | 0 | } |