/src/llama.cpp/src/models/plamo2.cpp
Line | Count | Source |
1 | | #include "models.h" |
2 | | #include "llama-memory-recurrent.h" |
3 | | |
4 | 0 | void llama_model_plamo2::load_arch_hparams(llama_model_loader & ml) { |
5 | 0 | ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); |
6 | | |
7 | | // Load Mamba SSM parameters |
8 | 0 | ml.get_key(LLM_KV_SSM_CONV_KERNEL, hparams.ssm_d_conv); |
9 | 0 | ml.get_key(LLM_KV_SSM_INNER_SIZE, hparams.ssm_d_inner); |
10 | 0 | ml.get_key(LLM_KV_SSM_STATE_SIZE, hparams.ssm_d_state); |
11 | 0 | ml.get_key(LLM_KV_SSM_TIME_STEP_RANK, hparams.ssm_dt_rank); |
12 | 0 | ml.get_key(LLM_KV_SSM_GROUP_COUNT, hparams.ssm_n_group); |
13 | | |
14 | | // Load attention parameters |
15 | 0 | ml.get_key(LLM_KV_ATTENTION_KEY_LENGTH, hparams.n_embd_head_k_full, false); |
16 | 0 | ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH, hparams.n_embd_head_v_full, false); |
17 | |
|
18 | 0 | for (uint32_t i = 0; i < hparams.n_layer(); ++i) { |
19 | 0 | hparams.is_recr_impl[i] = hparams.n_head_kv(i) == 0; |
20 | 0 | } |
21 | |
|
22 | 0 | switch (hparams.n_layer()) { |
23 | 0 | case 16: type = LLM_TYPE_1B; break; |
24 | 0 | case 32: |
25 | 0 | if (hparams.n_embd == 2048) { |
26 | 0 | type = LLM_TYPE_2B; |
27 | 0 | } else if (hparams.n_embd == 4096) { |
28 | 0 | type = LLM_TYPE_8B; |
29 | 0 | } |
30 | 0 | break; |
31 | 0 | default: type = LLM_TYPE_UNKNOWN; |
32 | 0 | } |
33 | 0 | } |
34 | | |
35 | 0 | void llama_model_plamo2::load_arch_tensors(llama_model_loader &) { |
36 | 0 | LLAMA_LOAD_LOCALS; |
37 | | |
38 | | // mamba parameters |
39 | 0 | const uint32_t d_conv = hparams.ssm_d_conv; |
40 | 0 | const uint32_t d_state = hparams.ssm_d_state; |
41 | 0 | const uint32_t num_heads = hparams.ssm_dt_rank; |
42 | 0 | const uint32_t intermediate_size = hparams.ssm_d_inner; |
43 | 0 | const int64_t dt_dim = std::max(64, int(hparams.n_embd / 16)); |
44 | | |
45 | | // attention parameters |
46 | 0 | const uint32_t qk_dim = hparams.n_embd_head_k(); |
47 | 0 | const uint32_t v_dim = hparams.n_embd_head_v(); |
48 | |
|
49 | 0 | tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); |
50 | | |
51 | | // output |
52 | 0 | output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); |
53 | 0 | output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); |
54 | | // if output is NULL, init from the input tok embed |
55 | 0 | if (output == NULL) { |
56 | 0 | output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); |
57 | 0 | } |
58 | |
|
59 | 0 | for (int i = 0; i < n_layer; ++i) { |
60 | 0 | auto & layer = layers[i]; |
61 | 0 | bool is_mamba_layer = hparams.is_recr(i); |
62 | |
|
63 | 0 | layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); |
64 | |
|
65 | 0 | if (is_mamba_layer) { |
66 | 0 | layer.ssm_in = create_tensor(tn(LLM_TENSOR_SSM_IN, "weight", i), {n_embd, 2 * intermediate_size}, 0); |
67 | 0 | layer.ssm_conv1d = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "weight", i), {d_conv, intermediate_size}, 0); |
68 | |
|
69 | 0 | layer.ssm_x = create_tensor(tn(LLM_TENSOR_SSM_X, "weight", i), {intermediate_size, dt_dim + 2*d_state}, 0); |
70 | 0 | layer.ssm_dt = create_tensor(tn(LLM_TENSOR_SSM_DT, "weight", i), {dt_dim, num_heads}, 0); |
71 | 0 | layer.ssm_dt_b = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias", i), {num_heads}, 0); |
72 | |
|
73 | 0 | layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A, i), {num_heads}, 0); |
74 | 0 | layer.ssm_d = create_tensor(tn(LLM_TENSOR_SSM_D, i), {num_heads}, 0); |
75 | |
|
76 | 0 | layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", i), {intermediate_size, n_embd}, 0); |
77 | |
|
78 | 0 | layer.ssm_dt_norm = create_tensor(tn(LLM_TENSOR_SSM_DT_NORM, i), {dt_dim}, 0); |
79 | 0 | layer.ssm_b_norm = create_tensor(tn(LLM_TENSOR_SSM_B_NORM, i), {d_state}, 0); |
80 | 0 | layer.ssm_c_norm = create_tensor(tn(LLM_TENSOR_SSM_C_NORM, i), {d_state}, 0); |
81 | 0 | } else { |
82 | 0 | const int64_t num_attention_heads = hparams.n_head(i); |
83 | 0 | const int64_t q_num_heads = num_attention_heads; |
84 | 0 | const int64_t num_key_value_heads = hparams.n_head_kv(i); |
85 | 0 | const int64_t k_num_heads = num_key_value_heads; |
86 | 0 | const int64_t v_num_heads = num_key_value_heads; |
87 | 0 | const int64_t q_proj_dim = q_num_heads * qk_dim; |
88 | 0 | const int64_t k_proj_dim = k_num_heads * qk_dim; |
89 | 0 | const int64_t v_proj_dim = v_num_heads * v_dim; |
90 | |
|
91 | 0 | layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, q_proj_dim + k_proj_dim + v_proj_dim}, 0); |
92 | 0 | layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {qk_dim, num_attention_heads}, 0); |
93 | 0 | layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {qk_dim, k_num_heads}, 0); |
94 | 0 | layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {q_num_heads * v_dim, n_embd}, 0); |
95 | 0 | } |
96 | | |
97 | | // All layers have post-attention norm, FFN norm, and FFN tensors |
98 | 0 | layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, i), {n_embd}, 0); |
99 | 0 | layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); |
100 | 0 | layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); |
101 | 0 | layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff * 2}, 0); |
102 | 0 | layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, i), {n_embd}, 0); |
103 | 0 | } |
104 | 0 | } |
105 | | |
106 | 0 | std::unique_ptr<llm_graph_context> llama_model_plamo2::build_arch_graph(const llm_graph_params & params) const { |
107 | 0 | return std::make_unique<graph>(*this, params); |
108 | 0 | } |
109 | | |
110 | | llama_model_plamo2::graph::graph(const llama_model & model, const llm_graph_params & params) : |
111 | 0 | llm_build_mamba_base(params) { |
112 | 0 | ggml_tensor * cur; |
113 | 0 | ggml_tensor * inpL; |
114 | | |
115 | | // {n_embd, n_tokens} |
116 | 0 | inpL = build_inp_embd(model.tok_embd); |
117 | 0 | cb(inpL, "embedding_output", -1); |
118 | |
|
119 | 0 | ggml_tensor * inp_pos = build_inp_pos(); |
120 | |
|
121 | 0 | auto * inp_hybrid = build_inp_mem_hybrid(); |
122 | |
|
123 | 0 | ggml_tensor * inp_out_ids = build_inp_out_ids(); |
124 | |
|
125 | 0 | for (int il = 0; il < n_layer; ++il) { |
126 | 0 | ggml_tensor * residual = inpL; |
127 | | |
128 | | // ggml_graph_add_node(gf, model.layers[il].attn_norm); |
129 | | // cb(model.layers[il].attn_norm, "attn_norm", il); |
130 | | |
131 | | // pre_mixer_norm |
132 | 0 | cur = build_norm(inpL, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il); |
133 | | |
134 | | // check if this layer is Mamba or Attention |
135 | 0 | const bool is_mamba_layer = hparams.is_recr(il); |
136 | |
|
137 | 0 | if (is_mamba_layer) { |
138 | | // PLaMo-2 Mamba layer |
139 | 0 | cur = build_plamo2_mamba_layer(inp_hybrid->get_recr(), cur, model, ubatch, il); |
140 | 0 | } else { |
141 | | // PLaMo-2 Attention layer |
142 | 0 | cur = build_plamo2_attn_layer(inp_hybrid->get_attn(), inp_pos, cur, model, il); |
143 | 0 | } |
144 | | |
145 | | // post_mixer_norm |
146 | 0 | cur = build_norm(cur, model.layers[il].attn_post_norm, NULL, LLM_NORM_RMS, il); |
147 | 0 | cb(cur, "attn_post_norm", il); |
148 | | |
149 | | // residual connection |
150 | 0 | cur = ggml_add(ctx0, cur, residual); |
151 | 0 | cb(cur, "attn_residual", il); |
152 | 0 | residual = cur; |
153 | | |
154 | | // pre-ffn norm |
155 | 0 | cur = build_norm(cur, model.layers[il].ffn_norm, NULL, LLM_NORM_RMS, il); |
156 | 0 | cb(cur, "ffn_pre_norm", il); |
157 | | |
158 | | // feed-forward network |
159 | 0 | cur = build_ffn(cur, |
160 | 0 | model.layers[il].ffn_up, NULL, NULL, |
161 | 0 | NULL, NULL, NULL, |
162 | 0 | model.layers[il].ffn_down, NULL, NULL, |
163 | 0 | NULL, LLM_FFN_SWIGLU, LLM_FFN_SEQ, il); |
164 | 0 | cb(cur, "ffn_out", il); |
165 | | |
166 | | // post ffn norm |
167 | 0 | cur = build_norm(cur, model.layers[il].ffn_post_norm, NULL, LLM_NORM_RMS, il); |
168 | 0 | cb(cur, "ffn_post_norm", il); |
169 | |
|
170 | 0 | if (il == n_layer - 1 && inp_out_ids) { |
171 | 0 | cur = ggml_get_rows(ctx0, cur, inp_out_ids); |
172 | 0 | residual = ggml_get_rows(ctx0, residual, inp_out_ids); |
173 | 0 | } |
174 | | |
175 | | // residual connection |
176 | 0 | cur = ggml_add(ctx0, cur, residual); |
177 | 0 | cb(cur, "ffn_residual", il); |
178 | | |
179 | | // input for next layer |
180 | 0 | inpL = cur; |
181 | 0 | } |
182 | |
|
183 | 0 | cur = inpL; |
184 | | |
185 | | // final norm |
186 | 0 | cur = build_norm(cur, model.output_norm, NULL, LLM_NORM_RMS, -1); |
187 | 0 | cb(cur, "result_norm", -1); |
188 | |
|
189 | 0 | res->t_embd = cur; |
190 | | |
191 | | // lm_head |
192 | 0 | cur = build_lora_mm(model.output, cur, model.output_s); |
193 | 0 | cb(cur, "result_output", -1); |
194 | | |
195 | | // Explicitly mark as output tensor to ensure proper backend assignment |
196 | 0 | ggml_set_output(cur); |
197 | |
|
198 | 0 | res->t_logits = cur; |
199 | |
|
200 | 0 | ggml_build_forward_expand(gf, cur); |
201 | 0 | } |
202 | | |
203 | | ggml_tensor * llama_model_plamo2::graph::build_plamo2_attn_layer(llm_graph_input_attn_kv * inp, |
204 | | ggml_tensor * inp_pos, |
205 | | ggml_tensor * cur, |
206 | | const llama_model & model, |
207 | 0 | int il) { |
208 | | // self-attention |
209 | 0 | { |
210 | | // PLaMo-2 uses combined QKV tensor |
211 | 0 | ggml_tensor * qkv = build_lora_mm(model.layers[il].wqkv, cur); |
212 | 0 | cb(qkv, "wqkv", il); |
213 | | |
214 | | // split QKV tensor into Q, K, V |
215 | 0 | const int64_t n_embd_head_q = hparams.n_embd_head_k(); |
216 | 0 | const int64_t n_embd_head_k = hparams.n_embd_head_k(); |
217 | 0 | const int64_t n_embd_head_v = hparams.n_embd_head_v(); |
218 | 0 | int32_t n_head = hparams.n_head(il); |
219 | 0 | int32_t n_head_kv = hparams.n_head_kv(il); |
220 | |
|
221 | 0 | const int64_t q_offset = 0; |
222 | 0 | const int64_t k_offset = n_embd_head_q * n_head; |
223 | 0 | const int64_t v_offset = k_offset + n_embd_head_k * n_head_kv; |
224 | |
|
225 | 0 | ggml_tensor * Qcur = ggml_view_3d(ctx0, qkv, n_embd_head_q, n_head, n_tokens, n_embd_head_q * sizeof(float), |
226 | 0 | qkv->nb[1], q_offset * ggml_element_size(qkv)); |
227 | 0 | ggml_tensor * Kcur = ggml_view_3d(ctx0, qkv, n_embd_head_k, n_head_kv, n_tokens, n_embd_head_k * sizeof(float), |
228 | 0 | qkv->nb[1], k_offset * ggml_element_size(qkv)); |
229 | 0 | ggml_tensor * Vcur = ggml_view_3d(ctx0, qkv, n_embd_head_v, n_head_kv, n_tokens, n_embd_head_v * sizeof(float), |
230 | 0 | qkv->nb[1], v_offset * ggml_element_size(qkv)); |
231 | |
|
232 | 0 | cb(Qcur, "Qcur", il); |
233 | 0 | cb(Kcur, "Kcur", il); |
234 | 0 | cb(Vcur, "Vcur", il); |
235 | |
|
236 | 0 | Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il); |
237 | 0 | cb(Qcur, "Qcur_normed", il); |
238 | |
|
239 | 0 | Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, |
240 | 0 | ext_factor, attn_factor, beta_fast, beta_slow); |
241 | |
|
242 | 0 | Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il); |
243 | 0 | cb(Kcur, "Kcur_normed", il); |
244 | |
|
245 | 0 | Kcur = ggml_rope_ext(ctx0, Kcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, |
246 | 0 | ext_factor, attn_factor, beta_fast, beta_slow); |
247 | |
|
248 | 0 | cur = build_attn(inp, |
249 | 0 | model.layers[il].wo, NULL, model.layers[il].wo_s, |
250 | 0 | Qcur, Kcur, Vcur, NULL, NULL, NULL, 1.0f / sqrtf(float(n_embd_head_v)), il); |
251 | 0 | } |
252 | |
|
253 | 0 | cb(cur, "attn_out", il); |
254 | |
|
255 | 0 | return cur; |
256 | 0 | } |
257 | | |
258 | | ggml_tensor * llama_model_plamo2::graph::build_plamo2_mamba_layer(llm_graph_input_rs * inp, |
259 | | ggml_tensor * cur, |
260 | | const llama_model & model, |
261 | | const llama_ubatch & ubatch, |
262 | 0 | int il) { |
263 | 0 | const auto * mctx_cur = inp->mctx; |
264 | |
|
265 | 0 | const auto kv_head = mctx_cur->get_head(); |
266 | |
|
267 | 0 | const int64_t d_conv = hparams.ssm_d_conv; |
268 | 0 | const int64_t d_inner = hparams.ssm_d_inner; |
269 | 0 | const int64_t d_state = hparams.ssm_d_state; |
270 | 0 | const int64_t n_heads = hparams.ssm_dt_rank; |
271 | 0 | const int64_t head_dim = d_inner / n_heads; |
272 | 0 | const int64_t n_group = hparams.ssm_n_group; |
273 | 0 | const int64_t n_seqs = ubatch.n_seqs; |
274 | |
|
275 | 0 | const int64_t n_seq_tokens = ubatch.n_seq_tokens; |
276 | |
|
277 | 0 | GGML_ASSERT(n_seqs != 0); |
278 | 0 | GGML_ASSERT(ubatch.equal_seqs()); |
279 | 0 | GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs); |
280 | 0 | GGML_ASSERT(d_inner % n_heads == 0); |
281 | 0 | GGML_ASSERT(n_group == 0); |
282 | |
|
283 | 0 | ggml_tensor * conv_states_all = mctx_cur->get_r_l(il); |
284 | 0 | ggml_tensor * ssm_states_all = mctx_cur->get_s_l(il); |
285 | |
|
286 | 0 | ggml_tensor * conv = build_rs(inp, conv_states_all, hparams.n_embd_r(), n_seqs); |
287 | 0 | conv = ggml_reshape_3d(ctx0, conv, d_conv - 1, d_inner + 2 * n_group * d_state, n_seqs); |
288 | | |
289 | | // {n_embd, n_tokens} => {n_embd, n_seq_tokens, n_seqs} |
290 | 0 | cur = ggml_reshape_3d(ctx0, cur, cur->ne[0], n_seq_tokens, n_seqs); |
291 | | |
292 | | // in_proj: {n_embd, 2*d_inner} @ {n_embd, n_seq_tokens, n_seqs} => {2*d_inner, n_seq_tokens, n_seqs} |
293 | 0 | ggml_tensor * zx = build_lora_mm(model.layers[il].ssm_in, cur); |
294 | 0 | cb(zx, "mamba_in_proj", il); |
295 | | // {8192, 5, 1, 1} -> {8192, 1, 5, 1} |
296 | 0 | zx = ggml_permute(ctx0, zx, 0, 2, 1, 3); |
297 | 0 | zx = ggml_cont_4d(ctx0, zx, head_dim * 2, n_heads, n_seq_tokens, n_seqs); |
298 | 0 | cb(zx, "mamba_in_proj_out", il); |
299 | | |
300 | | // split into z and x |
301 | | // => {head_dim * n_heads, n_seq_tokens, n_seqs} |
302 | 0 | ggml_tensor * x = ggml_view_4d(ctx0, zx, head_dim, n_heads, n_seq_tokens, n_seqs, zx->nb[1], zx->nb[2], zx->nb[3], |
303 | 0 | head_dim * ggml_element_size(zx)); |
304 | 0 | x = ggml_cont_3d(ctx0, x, head_dim * n_heads, n_seq_tokens, n_seqs); |
305 | | // x = ggml_permute(ctx0, x, 0, 2, 1, 3); |
306 | 0 | cb(x, "mamba_x_split", il); |
307 | |
|
308 | 0 | ggml_tensor * z = |
309 | 0 | ggml_view_4d(ctx0, zx, head_dim, n_heads, n_seq_tokens, n_seqs, zx->nb[1], zx->nb[2], zx->nb[3], 0); |
310 | 0 | cb(z, "mamba_z_split", il); |
311 | | |
312 | | // conv1d |
313 | 0 | { |
314 | | // => {d_conv - 1 + n_seq_tokens, d_inner, n_seqs} |
315 | 0 | ggml_tensor * conv_x = ggml_concat(ctx0, conv, ggml_transpose(ctx0, x), 0); |
316 | 0 | cb(conv_x, "mamba_conv1d_input", il); |
317 | | |
318 | | // copy last (d_conv - 1) columns back into the state cache |
319 | 0 | ggml_tensor * last_conv = ggml_view_3d(ctx0, conv_x, d_conv - 1, d_inner, n_seqs, conv_x->nb[1], conv_x->nb[2], |
320 | 0 | n_seq_tokens * (conv_x->nb[0])); |
321 | |
|
322 | 0 | ggml_build_forward_expand(gf, ggml_cpy(ctx0, last_conv, |
323 | 0 | ggml_view_1d(ctx0, conv_states_all, |
324 | 0 | (d_conv - 1) * (d_inner + 2 * n_group * d_state) * (n_seqs), |
325 | 0 | kv_head * (d_conv - 1) * (d_inner + 2 * n_group * d_state) * |
326 | 0 | ggml_element_size(conv_states_all)))); |
327 | 0 | cb(conv_states_all, "mamba_conv1d_state", il); |
328 | | |
329 | | // 1D convolution |
330 | 0 | x = ggml_ssm_conv(ctx0, conv_x, model.layers[il].ssm_conv1d); |
331 | 0 | cb(x, "mamba_conv1d", il); |
332 | |
|
333 | 0 | x = ggml_silu(ctx0, x); |
334 | 0 | cb(x, "mamba_conv1d_silu", il); |
335 | 0 | } |
336 | | |
337 | | // SSM |
338 | 0 | { |
339 | | // bcdt_proj: {d_inner, dt_rank + 2*d_state} @ {d_inner, n_seq_tokens, n_seqs} => {dt_rank + 2*d_state, n_seq_tokens, n_seqs} |
340 | 0 | ggml_tensor * x_bcdt = build_lora_mm(model.layers[il].ssm_x, x); |
341 | 0 | cb(x_bcdt, "mamba_bcdt_proj", il); |
342 | | |
343 | | // split into dt, B, C |
344 | 0 | const int64_t dt_dim = std::max(64, int(hparams.n_embd / 16)); |
345 | 0 | ggml_tensor * B = ggml_view_3d(ctx0, x_bcdt, d_state, n_seq_tokens, n_seqs, x_bcdt->nb[1], x_bcdt->nb[2], 0); |
346 | 0 | ggml_tensor * C = ggml_view_3d(ctx0, x_bcdt, d_state, n_seq_tokens, n_seqs, x_bcdt->nb[1], x_bcdt->nb[2], |
347 | 0 | ggml_element_size(x_bcdt) * d_state); |
348 | 0 | ggml_tensor * dt = ggml_view_3d(ctx0, x_bcdt, dt_dim, n_seq_tokens, n_seqs, x_bcdt->nb[1], x_bcdt->nb[2], |
349 | 0 | ggml_element_size(x_bcdt) * (2 * d_state)); |
350 | 0 | cb(B, "mamba_B_raw", il); |
351 | 0 | cb(C, "mamba_C_raw", il); |
352 | 0 | cb(dt, "mamba_dt_raw", il); |
353 | | |
354 | | // Apply RMS norm to dt, B, C (PLaMo-2 specific) |
355 | 0 | B = build_norm(B, model.layers[il].ssm_b_norm, NULL, LLM_NORM_RMS, il); |
356 | 0 | C = build_norm(C, model.layers[il].ssm_c_norm, NULL, LLM_NORM_RMS, il); |
357 | 0 | dt = build_norm(dt, model.layers[il].ssm_dt_norm, NULL, LLM_NORM_RMS, il); |
358 | 0 | cb(B, "mamba_B_normed", il); |
359 | 0 | cb(C, "mamba_C_normed", il); |
360 | 0 | cb(dt, "mamba_dt_normed", il); |
361 | | |
362 | | // dt_proj: {dt_rank, d_inner} @ {dt_rank, n_seq_tokens, n_seqs} => {d_inner, n_seq_tokens, n_seqs} |
363 | 0 | dt = build_lora_mm(model.layers[il].ssm_dt, dt); |
364 | 0 | dt = ggml_add(ctx0, dt, model.layers[il].ssm_dt_b); |
365 | 0 | cb(dt, "mamba_dt_proj", il); |
366 | |
|
367 | 0 | ggml_tensor * A = ggml_reshape_2d(ctx0, model.layers[il].ssm_a, 1, n_heads); |
368 | 0 | cb(A, "mamba_A", il); |
369 | |
|
370 | 0 | x = ggml_view_4d(ctx0, x, head_dim, n_heads, n_seq_tokens, n_seqs, head_dim * ggml_element_size(x), |
371 | 0 | head_dim * n_heads * ggml_element_size(x), |
372 | 0 | head_dim * n_heads * n_seq_tokens * ggml_element_size(x), 0); |
373 | 0 | B = ggml_view_4d(ctx0, B, d_state, 1, n_seq_tokens, n_seqs, d_state * B->nb[0], B->nb[1], B->nb[2], 0); |
374 | 0 | C = ggml_view_4d(ctx0, C, d_state, 1, n_seq_tokens, n_seqs, d_state * C->nb[0], C->nb[1], C->nb[2], 0); |
375 | | |
376 | | // use the states and the indices provided by build_recurrent_state |
377 | | // (this is necessary in order to properly use the states before they are overwritten, |
378 | | // while avoiding to make unnecessary copies of the states) |
379 | 0 | auto get_ssm_rows = [&](ggml_context * ctx, ggml_tensor * states, ggml_tensor * ids) { |
380 | 0 | ggml_tensor * ssm = ggml_reshape_4d(ctx, states, d_state, head_dim, n_heads, mctx_cur->get_size()); |
381 | | |
382 | | // Custom operator to optimize the parallel associative scan |
383 | | // as described in the Annex D of the Mamba paper. |
384 | | // => {d_inner, n_seq_tokens, n_seqs} and {d_state, d_inner, n_seqs} |
385 | 0 | return ggml_ssm_scan(ctx, ssm, x, dt, A, B, C, ids); |
386 | 0 | }; |
387 | |
|
388 | 0 | ggml_tensor * y_ssm = build_rs(inp, ssm_states_all, hparams.n_embd_s(), ubatch.n_seqs, get_ssm_rows); |
389 | 0 | cb(y_ssm, "mamba_ssm_scan", il); |
390 | | |
391 | | // store last states |
392 | 0 | ggml_build_forward_expand( |
393 | 0 | gf, ggml_cpy( |
394 | 0 | ctx0, |
395 | 0 | ggml_view_1d(ctx0, y_ssm, n_heads * head_dim * d_state * n_seqs, |
396 | 0 | n_heads * head_dim * n_seq_tokens * n_seqs * ggml_element_size(y_ssm)), |
397 | 0 | ggml_view_1d(ctx0, ssm_states_all, n_heads * head_dim * d_state * n_seqs, |
398 | 0 | kv_head * n_seqs * n_heads * head_dim * d_state * ggml_element_size(ssm_states_all)))); |
399 | 0 | cb(ssm_states_all, "mamba_ssm_states", il); |
400 | |
|
401 | 0 | ggml_tensor * y = ggml_view_4d(ctx0, y_ssm, head_dim, n_heads, n_seq_tokens, n_seqs, |
402 | 0 | head_dim * ggml_element_size(x), head_dim * n_heads * ggml_element_size(x), |
403 | 0 | head_dim * n_heads * n_seq_tokens * ggml_element_size(x), 0); |
404 | 0 | cb(y, "mamba_y_view", il); |
405 | | |
406 | | // Add D parameter and apply gating with z |
407 | | // {d_inner, n_seq_tokens, n_seqs} * {d_inner} => {d_inner, n_seq_tokens, n_seqs} |
408 | 0 | ggml_tensor * D = ggml_reshape_2d(ctx0, model.layers[il].ssm_d, 1, n_heads); |
409 | 0 | y = ggml_add(ctx0, y, ggml_mul(ctx0, x, D)); |
410 | 0 | cb(y, "mamba_y_add_d", il); |
411 | |
|
412 | 0 | y = ggml_swiglu_split(ctx0, ggml_cont(ctx0, z), y); |
413 | 0 | cb(y, "mamba_y_swiglu_z", il); |
414 | | |
415 | | // out_proj: {d_inner, n_embd} @ {d_inner, n_seq_tokens, n_seqs} => {n_embd, n_seq_tokens, n_seqs} |
416 | 0 | y = ggml_view_3d(ctx0, y, head_dim * n_heads, n_seq_tokens, n_seqs, y->nb[2], y->nb[3], 0); |
417 | 0 | cur = build_lora_mm(model.layers[il].ssm_out, y); |
418 | 0 | cb(cur, "mamba_out_proj", il); |
419 | 0 | } |
420 | | |
421 | | // {n_embd, n_seq_tokens, n_seqs} => {n_embd, n_tokens} |
422 | 0 | cur = ggml_reshape_2d(ctx0, cur, cur->ne[0], n_seq_tokens * n_seqs); |
423 | 0 | cb(cur, "mamba_out", il); |
424 | |
|
425 | 0 | return cur; |
426 | 0 | } |