/src/llama.cpp/src/models/kimi-linear.cpp
Line | Count | Source |
1 | | #include "models.h" |
2 | | |
3 | | #include "llama-memory-recurrent.h" |
4 | | |
5 | | // Causal Conv1d function for Q,K,V |
6 | | // When qkv is 0, it is Q, 1 is K, 2 is V |
7 | 0 | static ggml_tensor * causal_conv1d(ggml_cgraph * gf, ggml_context * ctx0, ggml_tensor * conv_states_all, ggml_tensor * conv_state_all, int64_t qkv, ggml_tensor * x, ggml_tensor * proj_w, ggml_tensor * conv_w, int64_t d_conv, int64_t head_dim, int64_t n_head, int64_t n_seq_tokens, int64_t n_seqs, int64_t n_tokens, int64_t kv_head) { |
8 | 0 | const int64_t d_inner = head_dim * n_head; |
9 | 0 | const int64_t conv_state_size = (d_conv - 1) * d_inner; |
10 | 0 | const int64_t n_embd_r_total = 3 * conv_state_size; // Q + K + V |
11 | | |
12 | | // conv_state_all is [n_embd_r_total, n_seqs], split into Q, K, V |
13 | | // Each conv state is [(d_conv-1) * d_inner] per sequence, need to reshape to [d_conv-1, d_inner, n_seqs] |
14 | | // Memory layout: for each seq, Q state is first conv_state_size elements, then K, then V |
15 | | // conv_state_all has stride: nb[0] = element_size, nb[1] = n_embd_r_total * element_size |
16 | | // View Q conv state: offset 0, size conv_state_size per seq |
17 | | // conv_state_all is [n_embd_r_total, n_seqs] with memory layout: |
18 | | // state[i + seq * n_embd_r_total] where i = conv_step + channel * (d_conv-1) + {0, conv_state_size, 2*conv_state_size} for Q/K/V |
19 | | // We want [d_conv-1, d_inner, n_seqs] view: |
20 | | // nb1 = (d_conv-1) * element_size (stride between channels) |
21 | | // nb2 = n_embd_r_total * element_size (stride between seqs) |
22 | 0 | ggml_tensor * conv_state_x = ggml_view_3d(ctx0, conv_state_all, d_conv - 1, d_inner, n_seqs, |
23 | 0 | (d_conv - 1) * ggml_element_size(conv_state_all), // nb1: stride between channels |
24 | 0 | n_embd_r_total * ggml_element_size(conv_state_all), // nb2: stride between seqs |
25 | 0 | qkv * conv_state_size * ggml_element_size(conv_state_all)); |
26 | | |
27 | | // Causal Conv1d function for Q,K,V |
28 | | // When qkv is 0, it is Q, 1 is K, 2 is V |
29 | | // Step 1: Q, K, V projections -> [d_inner, n_tokens] |
30 | 0 | ggml_tensor * x_proj = ggml_mul_mat(ctx0, proj_w, x); |
31 | | |
32 | | // Reshape input: {d_inner, n_tokens} -> {d_inner, n_seq_tokens, n_seqs} |
33 | 0 | ggml_tensor * x_3d = ggml_reshape_3d(ctx0, x_proj, d_inner, n_seq_tokens, n_seqs); |
34 | | |
35 | | // Concat Q conv state and current input: {d_conv-1 + n_seq_tokens, d_inner, n_seqs} |
36 | 0 | ggml_tensor * conv_x = ggml_concat(ctx0, conv_state_x, ggml_transpose(ctx0, x_3d), 0); |
37 | | |
38 | | // Save last (d_conv-1) columns back to Q conv state |
39 | 0 | ggml_tensor * last_conv_x = ggml_view_3d(ctx0, conv_x, d_conv - 1, d_inner, n_seqs, |
40 | 0 | conv_x->nb[1], conv_x->nb[2], n_seq_tokens * conv_x->nb[0]); |
41 | 0 | ggml_build_forward_expand(gf, |
42 | 0 | ggml_cpy(ctx0, last_conv_x, |
43 | 0 | ggml_view_3d(ctx0, conv_states_all, |
44 | 0 | d_conv - 1, d_inner, n_seqs, |
45 | 0 | (d_conv - 1) * ggml_element_size(conv_states_all), // nb1: contiguous within one channel's conv taps |
46 | 0 | n_embd_r_total * ggml_element_size(conv_states_all), // nb2: stride between sequences (skip over K,V states) |
47 | 0 | (kv_head * n_embd_r_total + qkv * conv_state_size) * ggml_element_size(conv_states_all)))); // offset to first seq's Q/K/V state |
48 | | // Reshape conv weight: GGUF [d_conv, 1, d_inner, 1] -> ggml_ssm_conv expects [d_conv, d_inner] |
49 | | // GGUF stores as [d_conv, 1, d_inner, 1] with memory layout w[conv_step + channel * d_conv] |
50 | | // vLLM stores as [d_inner, d_conv] with memory layout w[channel * d_conv + conv_step] |
51 | | // ggml_ssm_conv computes: c[conv_step + channel * d_conv] |
52 | | // GGUF layout: [d_conv, 1, d_inner] or [d_conv, 1, d_inner, 1] -> reshape to [d_conv, d_inner] |
53 | | // Reshape conv weight from [d_conv, 1, d_inner, 1] to [d_conv, d_inner] for ggml_ssm_conv |
54 | 0 | ggml_tensor * conv_weight = ggml_reshape_2d(ctx0, conv_w, d_conv, d_inner); |
55 | | |
56 | | // Apply conv1d |
57 | | // ggml_ssm_conv output: {d_inner, n_seq_tokens, n_seqs} |
58 | 0 | ggml_tensor * Xcur = ggml_ssm_conv(ctx0, conv_x, conv_weight); |
59 | | // Reshape to 2D for bias add: {d_inner, n_tokens} |
60 | 0 | Xcur = ggml_reshape_2d(ctx0, Xcur, d_inner, n_tokens); |
61 | 0 | Xcur = ggml_silu(ctx0, Xcur); |
62 | |
|
63 | 0 | return ggml_reshape_4d(ctx0, Xcur, head_dim, n_head, n_seq_tokens, n_seqs); |
64 | 0 | } |
65 | | |
66 | | llm_build_kimi_linear::llm_build_kimi_linear(const llama_model & model, const llm_graph_params & params) : |
67 | 0 | llm_build_delta_net_base(params), model(model) { |
68 | 0 | ggml_tensor * cur; |
69 | 0 | ggml_tensor * inpL; |
70 | |
|
71 | 0 | inpL = build_inp_embd(model.tok_embd); |
72 | 0 | cb(inpL, "model.embed_tokens", -1); |
73 | | |
74 | | // Note: Kimi MLA does NOT use RoPE (rotary_emb=None in vLLM) |
75 | | // So we don't need inp_pos |
76 | |
|
77 | 0 | auto * inp_kv = !hparams.is_mla() ? build_inp_mem_hybrid() : nullptr; |
78 | 0 | auto * inp_k = hparams.is_mla() ? build_inp_mem_hybrid_k() : nullptr; |
79 | 0 | auto * inp_rs = hparams.is_mla() ? inp_k->get_recr() : inp_kv->get_recr(); |
80 | 0 | auto * inp_attn_kv = !hparams.is_mla() ? inp_kv->get_attn() : nullptr; |
81 | 0 | auto * inp_attn_k = hparams.is_mla() ? inp_k->get_attn() : nullptr; |
82 | | |
83 | | // Output ids for selecting which tokens to output |
84 | 0 | ggml_tensor * inp_out_ids = build_inp_out_ids(); |
85 | | |
86 | | // Kimi dimension constants |
87 | 0 | const int64_t n_head = hparams.n_head(); |
88 | 0 | const int64_t head_dim = hparams.n_embd_head_kda; |
89 | 0 | const int64_t d_conv = hparams.ssm_d_conv; |
90 | 0 | const int64_t d_inner = n_head * head_dim; // 32 * 128 = 4096 |
91 | 0 | const int64_t n_seqs = ubatch.n_seqs; |
92 | 0 | const int64_t n_seq_tokens = ubatch.n_seq_tokens; |
93 | | |
94 | | // Verify batch consistency for recurrent layers |
95 | 0 | GGML_ASSERT(n_seqs != 0); |
96 | 0 | GGML_ASSERT(ubatch.equal_seqs()); |
97 | 0 | GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs); |
98 | | |
99 | | // MLA params |
100 | 0 | const int64_t n_embd_head_k_mla = hparams.n_embd_head_k_mla(); |
101 | 0 | const int64_t n_embd_head_v_mla = hparams.n_embd_head_v_mla(); |
102 | 0 | const int64_t kv_lora_rank = hparams.n_lora_kv; |
103 | | // qk_rope_head_dim = 64 (from Kimi config) which is hparams.n_rot |
104 | | // Confirmed from tensor shape: wkv_a_mqa [2304, 576] = [n_embd, kv_lora_rank + qk_rope_head_dim] |
105 | 0 | const int64_t n_embd_head_qk_rope = hparams.n_rot(); // config.qk_rope_head_dim |
106 | 0 | const int64_t n_embd_head_qk_nope = n_embd_head_k_mla - n_embd_head_qk_rope; // 192 - 64 = 128 |
107 | | // Attention scale for MLA |
108 | 0 | const float kq_scale_mla = 1.0f / sqrtf((float)n_embd_head_k_mla); |
109 | |
|
110 | 0 | for (int il = 0; il < n_layer; ++il) { |
111 | 0 | const auto & layer = model.layers[il]; |
112 | 0 | ggml_tensor * inpSA = inpL; |
113 | | |
114 | | // Attention Norm |
115 | 0 | cur = build_norm(inpL, layer.attn_norm, NULL, LLM_NORM_RMS, il); |
116 | 0 | cb(cur, "attn_norm", il); |
117 | |
|
118 | 0 | ggml_build_forward_expand(gf, cur); |
119 | |
|
120 | 0 | if (hparams.is_recurrent(il)) { |
121 | | // === KDA Layer (Kimi Delta Attention) with Recurrent State === |
122 | | // Reference: vLLM kda.py |
123 | 0 | const auto * mctx_cur = inp_rs->mctx; |
124 | 0 | const auto kv_head = mctx_cur->get_head(); |
125 | | |
126 | | // Get conv states from r_l tensor (Q, K, V each have separate state) |
127 | 0 | ggml_tensor * conv_states_all = mctx_cur->get_r_l(il); |
128 | 0 | cb(conv_states_all, "conv_states_all", il); |
129 | 0 | ggml_tensor * conv_state_all = build_rs(inp_rs, conv_states_all, hparams.n_embd_r(), n_seqs); |
130 | 0 | ggml_tensor * Qcur = causal_conv1d(gf, ctx0, conv_states_all, conv_state_all, 0, cur, layer.wq, layer.ssm_q_conv, d_conv, head_dim, n_head, n_seq_tokens, n_seqs, n_tokens, kv_head); |
131 | 0 | ggml_tensor * Kcur = causal_conv1d(gf, ctx0, conv_states_all, conv_state_all, 1, cur, layer.wk, layer.ssm_k_conv, d_conv, head_dim, n_head, n_seq_tokens, n_seqs, n_tokens, kv_head); |
132 | 0 | ggml_tensor * Vcur = causal_conv1d(gf, ctx0, conv_states_all, conv_state_all, 2, cur, layer.wv, layer.ssm_v_conv, d_conv, head_dim, n_head, n_seq_tokens, n_seqs, n_tokens, kv_head); |
133 | | |
134 | | // g1 = -exp(A_log) * softplus(f_b(f_a(x)) + dt_bias) |
135 | 0 | ggml_tensor * f_a = ggml_mul_mat(ctx0, layer.ssm_f_a, cur); |
136 | 0 | ggml_tensor * g1 = ggml_mul_mat(ctx0, layer.ssm_f_b, f_a); |
137 | 0 | cb(g1, "g1 f_b(f_a(cur))", il); |
138 | 0 | g1 = ggml_add(ctx0, g1, layer.ssm_dt_b); |
139 | 0 | g1 = ggml_softplus(ctx0, g1); |
140 | 0 | g1 = ggml_reshape_3d(ctx0, g1, head_dim, n_head, n_tokens); |
141 | | |
142 | | // A_log shape is [1, n_head] or [1, n_head, 1, 1], need to broadcast to [head_dim, n_head, n_tokens]. No need to -exp(a_log) because it was done in convert_hf_to_gguf.py |
143 | | // Reshape to [1, n_head, 1] for broadcasting with g1 [head_dim, n_head, n_tokens] |
144 | 0 | ggml_tensor * A = ggml_reshape_3d(ctx0, layer.ssm_a, 1, n_head, 1); |
145 | 0 | g1 = ggml_mul(ctx0, g1, A); |
146 | 0 | cb(g1, "kda_g1", il); |
147 | |
|
148 | 0 | g1 = ggml_reshape_4d(ctx0, g1, head_dim, n_head, n_seq_tokens, n_seqs); |
149 | | |
150 | | // Compute beta (mixing coefficient) |
151 | 0 | ggml_tensor * beta = ggml_mul_mat(ctx0, layer.ssm_beta, cur); |
152 | 0 | beta = ggml_reshape_4d(ctx0, beta, 1, n_head, n_seq_tokens, n_seqs); |
153 | 0 | cb(beta, "kda_beta", il); |
154 | |
|
155 | 0 | beta = ggml_sigmoid(ctx0, beta); |
156 | | |
157 | | // Reshape for KDA recurrence |
158 | | // {n_embd, n_tokens} -> {n_embd, n_seq_tokens, n_seqs} |
159 | 0 | cur = ggml_reshape_3d(ctx0, cur, cur->ne[0], n_seq_tokens, n_seqs); |
160 | | |
161 | | // Get SSM state and compute KDA recurrence using ggml_kda_scan |
162 | 0 | ggml_tensor * ssm_states_all = mctx_cur->get_s_l(il); |
163 | 0 | ggml_tensor * state = build_rs(inp_rs, ssm_states_all, hparams.n_embd_s(), n_seqs); |
164 | 0 | state = ggml_reshape_4d(ctx0, state, head_dim, head_dim, n_head, n_seqs); |
165 | |
|
166 | 0 | const float eps_norm = hparams.f_norm_rms_eps; |
167 | |
|
168 | 0 | Qcur = ggml_l2_norm(ctx0, Qcur, eps_norm); |
169 | 0 | Kcur = ggml_l2_norm(ctx0, Kcur, eps_norm); |
170 | | |
171 | | // Choose between build_delta_net_chunking and build_delta_net_recurrent based on n_tokens |
172 | 0 | auto attn_out = build_delta_net(Qcur, Kcur, Vcur, g1, beta, state, il); |
173 | |
|
174 | 0 | ggml_tensor * output = ggml_cont(ctx0, attn_out.first); |
175 | 0 | ggml_tensor * new_state = attn_out.second; |
176 | 0 | cb(output, "attn_output", il); |
177 | 0 | cb(new_state, "new_state", il); |
178 | | |
179 | | // Update the recurrent states |
180 | 0 | ggml_build_forward_expand(gf, |
181 | 0 | ggml_cpy(ctx0, new_state, |
182 | 0 | ggml_view_1d(ctx0, ssm_states_all, hparams.n_embd_s() * n_seqs, |
183 | 0 | kv_head * hparams.n_embd_s() * ggml_element_size(ssm_states_all)))); |
184 | | |
185 | | // Output gating g2 = g_b(g_a(x)) |
186 | 0 | ggml_tensor * cur_2d = ggml_reshape_2d(ctx0, cur, cur->ne[0], n_seq_tokens * n_seqs); |
187 | 0 | ggml_tensor * g_a = ggml_mul_mat(ctx0, layer.ssm_g_a, cur_2d); |
188 | 0 | ggml_tensor * g2 = ggml_mul_mat(ctx0, layer.ssm_g_b, g_a); |
189 | 0 | cb(g2, "g2 g_b(g_a(cur_2d))", il); |
190 | 0 | g2 = ggml_reshape_3d(ctx0, g2, head_dim, n_head, n_seq_tokens * n_seqs); |
191 | | |
192 | | // Apply o_norm with sigmoid gating |
193 | | // Note: Kimi model uses sigmoid gating, not SiLU (despite FusedRMSNormGated default being swish) |
194 | | // Formula: output = RMSNorm(x) * sigmoid(g) |
195 | 0 | ggml_tensor * attn_out_final = ggml_reshape_3d(ctx0, output, head_dim, n_head, n_seq_tokens * n_seqs); |
196 | 0 | ggml_tensor * normed = build_norm(attn_out_final, layer.ssm_o_norm, nullptr, LLM_NORM_RMS, il); |
197 | 0 | cb(normed, "kda_normed", il); |
198 | 0 | ggml_tensor * gate = ggml_sigmoid(ctx0, g2); |
199 | 0 | ggml_tensor * gated = ggml_mul(ctx0, normed, gate); |
200 | | |
201 | | // Output projection |
202 | 0 | gated = ggml_cont_2d(ctx0, gated, d_inner, n_tokens); |
203 | 0 | cur = ggml_mul_mat(ctx0, layer.wo, gated); |
204 | 0 | cb(cur, "kda_out", il); |
205 | |
|
206 | 0 | } else { |
207 | | // === MLA Layer (Multi-head Latent Attention) without KV Cache === |
208 | | // Reference: vLLM mla.py |
209 | | // Step 1: Q projection and reshape |
210 | | // vLLM Kimi: q = q_proj(hidden_states), then view as [n_tokens, n_head, qk_head_dim] |
211 | | // Note: Kimi MLA does NOT use RoPE (rotary_emb=None in vLLM) |
212 | 0 | ggml_tensor * Qcur = ggml_mul_mat(ctx0, layer.wq, cur); |
213 | | |
214 | | // Step 2: KV compression |
215 | | // kv_cmpr_pe = kv_a_proj_with_mqa(hidden_states) -> [kv_lora_rank + qk_rope_head_dim, n_tokens] |
216 | 0 | ggml_tensor * kv_cmpr_pe = ggml_mul_mat(ctx0, layer.wkv_a_mqa, cur); |
217 | | |
218 | | // Split: kv_cmpr = kv_lora[:kv_lora_rank], k_pe = kv_lora[kv_lora_rank:] |
219 | 0 | ggml_tensor * kv_cmpr = ggml_view_2d(ctx0, kv_cmpr_pe, kv_lora_rank, n_tokens, |
220 | 0 | ggml_row_size(kv_cmpr_pe->type, kv_lora_rank + n_embd_head_qk_rope), 0); |
221 | 0 | ggml_tensor * k_pe = ggml_view_3d(ctx0, kv_cmpr_pe, n_embd_head_qk_rope, 1, n_tokens, |
222 | 0 | ggml_row_size(kv_cmpr_pe->type, kv_lora_rank + n_embd_head_qk_rope), |
223 | 0 | ggml_row_size(kv_cmpr_pe->type, kv_lora_rank + n_embd_head_qk_rope), |
224 | 0 | ggml_row_size(kv_cmpr_pe->type, kv_lora_rank)); |
225 | | // Note: Kimi MLA does NOT apply RoPE (rotary_emb=None in vLLM) |
226 | | // k_pe is used directly without RoPE |
227 | | // Normalize kv_c |
228 | 0 | kv_cmpr = build_norm(kv_cmpr, layer.attn_kv_a_norm, nullptr, LLM_NORM_RMS, il); |
229 | |
|
230 | 0 | if (layer.wk_b && layer.wv_b) { // MLA KV cache enabled |
231 | | // extract q_nope |
232 | 0 | ggml_tensor * q_nope = |
233 | 0 | ggml_view_3d(ctx0, Qcur, n_embd_head_qk_nope, n_head, n_tokens, ggml_row_size(Qcur->type, n_embd_head_k_mla), |
234 | 0 | ggml_row_size(Qcur->type, n_embd_head_k_mla) * n_head, 0); |
235 | 0 | cb(q_nope, "q_nope", il); |
236 | | |
237 | | // and {n_embd_head_qk_rope, n_head, n_tokens} |
238 | 0 | ggml_tensor * q_pe = ggml_view_3d( |
239 | 0 | ctx0, Qcur, n_embd_head_qk_rope, n_head, n_tokens, ggml_row_size(Qcur->type, n_embd_head_k_mla), |
240 | 0 | ggml_row_size(Qcur->type, n_embd_head_k_mla) * n_head, ggml_row_size(Qcur->type, n_embd_head_qk_nope)); |
241 | 0 | cb(q_pe, "q_pe", il); |
242 | | |
243 | | // {n_embd_head_qk_nope, n_tokens, n_head} |
244 | 0 | q_nope = ggml_permute(ctx0, q_nope, 0, 2, 1, 3); |
245 | 0 | cb(q_nope, "q_nope_perm", il); |
246 | | |
247 | | // {n_embd_head_qk_nope, kv_lora_rank, n_head} x {n_embd_head_qk_nope, n_tokens, n_head} |
248 | 0 | ggml_tensor * q_nope_absorbed = ggml_mul_mat(ctx0, layer.wk_b, q_nope); |
249 | 0 | cb(q_nope_absorbed, "q_nope_absorbed", il); |
250 | | |
251 | | // {kv_lora_rank, n_head, n_tokens} |
252 | 0 | q_nope_absorbed = ggml_permute(ctx0, q_nope_absorbed, 0, 2, 1, 3); |
253 | 0 | cb(q_nope_absorbed, "q_nope_absorbed_perm", il); |
254 | | |
255 | | // {n_embd_head_qk_rope + kv_lora_rank, n_head, n_tokens} |
256 | | // note: rope must go first for in-place context shifting in build_rope_shift() |
257 | 0 | Qcur = ggml_concat(ctx0, q_nope_absorbed, q_pe, 0); |
258 | 0 | cb(Qcur, "Qcur", il); |
259 | |
|
260 | 0 | kv_cmpr = ggml_reshape_3d(ctx0, kv_cmpr, kv_lora_rank, 1, n_tokens); |
261 | 0 | cb(kv_cmpr, "kv_cmpr_reshape", il); |
262 | | |
263 | | // {n_embd_head_qk_rope + kv_lora_rank, 1, n_tokens} |
264 | 0 | ggml_tensor * Kcur = ggml_concat(ctx0, kv_cmpr, k_pe, 0); |
265 | 0 | cb(Kcur, "Kcur", il); |
266 | | |
267 | | // {kv_lora_rank, 1, n_tokens} |
268 | 0 | ggml_tensor * Vcur = kv_cmpr; |
269 | 0 | cb(Vcur, "Vcur", il); |
270 | |
|
271 | 0 | cur = build_attn(inp_attn_k, layer.wo, NULL, Qcur, Kcur, Vcur, nullptr, nullptr, layer.wv_b, kq_scale_mla, il); |
272 | 0 | cb(cur, "mla_out", il); |
273 | 0 | } else { // MLA KV cache disabled. Fall back to MHA KV cache. |
274 | 0 | Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head_k_mla, n_head, n_tokens); |
275 | 0 | cb(Qcur, "mla_Q", il); |
276 | | // KV decompression: kv = kv_b_proj(kv_c_normed) |
277 | 0 | ggml_tensor * kv = ggml_mul_mat(ctx0, layer.wkv_b, kv_cmpr); |
278 | 0 | const int64_t kv_per_head = n_embd_head_qk_nope + n_embd_head_v_mla; |
279 | | |
280 | | // Split kv into k_nope and v |
281 | 0 | ggml_tensor * k_nope = ggml_view_3d(ctx0, kv, n_embd_head_qk_nope, n_head, n_tokens, |
282 | 0 | ggml_row_size(kv->type, kv_per_head), |
283 | 0 | ggml_row_size(kv->type, kv_per_head * n_head), 0); |
284 | 0 | ggml_tensor * Vcur = ggml_view_3d(ctx0, kv, n_embd_head_v_mla, n_head, n_tokens, |
285 | 0 | ggml_row_size(kv->type, kv_per_head), |
286 | 0 | ggml_row_size(kv->type, kv_per_head * n_head), |
287 | 0 | ggml_row_size(kv->type, n_embd_head_qk_nope)); |
288 | 0 | Vcur = ggml_cont(ctx0, Vcur); |
289 | 0 | cb(Vcur, "mla_V", il); |
290 | | |
291 | | // Concatenate k_nope + k_pe (broadcast k_pe to all heads) |
292 | | // K = [k_nope, k_pe] where k_nope is [qk_nope_head_dim, n_head, n_tokens] |
293 | | // and k_pe is [qk_rope_head_dim, 1, n_tokens] broadcast to all heads |
294 | | // Need to broadcast k_pe from [qk_rope, 1, n_tokens] to [qk_rope, n_head, n_tokens] |
295 | 0 | ggml_tensor * k_pe_target = ggml_new_tensor_3d(ctx0, k_pe->type, n_embd_head_qk_rope, n_head, n_tokens); |
296 | 0 | ggml_tensor * k_pe_repeated = ggml_repeat(ctx0, k_pe, k_pe_target); |
297 | 0 | ggml_tensor * Kcur = ggml_concat(ctx0, k_pe_repeated, k_nope, 0); |
298 | 0 | cb(Kcur, "mla_K", il); |
299 | | |
300 | | // Direct softmax attention (with MHA KV cache) |
301 | | // Use build_attn with inp_attn for proper mask handling |
302 | 0 | cur = build_attn(inp_attn_kv, layer.wo, NULL, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale_mla, il); |
303 | 0 | cb(cur, "mla_out", il); |
304 | 0 | } |
305 | 0 | } |
306 | | |
307 | | // On last layer, select only the output tokens |
308 | 0 | if (il == n_layer - 1 && inp_out_ids) { |
309 | 0 | cur = ggml_get_rows(ctx0, cur, inp_out_ids); |
310 | 0 | inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); |
311 | 0 | } |
312 | | |
313 | | // Residual |
314 | 0 | ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); |
315 | 0 | cb(ffn_inp, "ffn_inp", il); |
316 | | |
317 | | // FFN Norm |
318 | 0 | cur = build_norm(ffn_inp, layer.ffn_norm, NULL, LLM_NORM_RMS, il); |
319 | 0 | cb(cur, "ffn_norm", il); |
320 | |
|
321 | 0 | if ((uint32_t) il < hparams.n_layer_dense_lead) { |
322 | | // Dense FFN layer |
323 | 0 | cur = build_ffn(cur, |
324 | 0 | layer.ffn_up, NULL, NULL, |
325 | 0 | layer.ffn_gate, NULL, NULL, |
326 | 0 | layer.ffn_down, NULL, NULL, |
327 | 0 | NULL, LLM_FFN_SILU, LLM_FFN_PAR, il); |
328 | 0 | cb(cur, "ffn_out", il); |
329 | 0 | } else { |
330 | | // MoE layer |
331 | | // Kimi uses moe_renormalize=True and routed_scaling_factor (stored as expert_weights_scale) = 2.446 |
332 | 0 | ggml_tensor * moe_out = build_moe_ffn(cur, |
333 | 0 | layer.ffn_gate_inp, |
334 | 0 | layer.ffn_up_exps, |
335 | 0 | layer.ffn_gate_exps, |
336 | 0 | layer.ffn_down_exps, |
337 | 0 | layer.ffn_exp_probs_b, |
338 | 0 | hparams.n_expert, |
339 | 0 | hparams.n_expert_used, |
340 | 0 | LLM_FFN_SILU, true, |
341 | 0 | hparams.expert_weights_scale, |
342 | 0 | (llama_expert_gating_func_type) hparams.expert_gating_func, |
343 | 0 | il); |
344 | 0 | cb(moe_out, "ffn_moe_out", il); |
345 | | |
346 | | // Shared expert |
347 | 0 | { |
348 | 0 | ggml_tensor * ffn_shexp = build_ffn(cur, |
349 | 0 | layer.ffn_up_shexp, NULL, NULL, |
350 | 0 | layer.ffn_gate_shexp, NULL, NULL, |
351 | 0 | layer.ffn_down_shexp, NULL, NULL, |
352 | 0 | NULL, LLM_FFN_SILU, LLM_FFN_PAR, il); |
353 | 0 | cb(ffn_shexp, "ffn_shexp", il); |
354 | |
|
355 | 0 | cur = ggml_add(ctx0, moe_out, ffn_shexp); |
356 | 0 | cb(cur, "ffn_out", il); |
357 | 0 | } |
358 | 0 | } |
359 | | // Residual |
360 | 0 | cur = ggml_add(ctx0, cur, ffn_inp); |
361 | |
|
362 | 0 | cur = build_cvec(cur, il); |
363 | 0 | cb(cur, "l_out", il); |
364 | | |
365 | | // input for next layer |
366 | 0 | inpL = cur; |
367 | 0 | } |
368 | 0 | cur = inpL; |
369 | | |
370 | | // Final Norm |
371 | 0 | cur = build_norm(cur, model.output_norm, NULL, LLM_NORM_RMS, -1); |
372 | |
|
373 | 0 | cb(cur, "result_norm", -1); |
374 | 0 | res->t_embd = cur; |
375 | | |
376 | | // Output |
377 | 0 | cur = ggml_mul_mat(ctx0, model.output, cur); |
378 | 0 | cb(cur, "result_output", -1); |
379 | 0 | res->t_logits = cur; |
380 | |
|
381 | 0 | ggml_build_forward_expand(gf, cur); |
382 | 0 | } |