/src/llama.cpp/src/models/kimi-linear.cpp
Line | Count | Source |
1 | | #include "models.h" |
2 | | #include "llama-memory-recurrent.h" |
3 | | |
4 | 0 | void llama_model_kimi_linear::load_arch_hparams(llama_model_loader & ml) { |
5 | 0 | ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); |
6 | 0 | ml.get_key(LLM_KV_ATTENTION_KEY_LENGTH_MLA, hparams.n_embd_head_k_mla_impl); |
7 | 0 | ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH_MLA, hparams.n_embd_head_v_mla_impl); |
8 | 0 | ml.get_key(LLM_KV_ATTENTION_KV_LORA_RANK, hparams.n_lora_kv); |
9 | 0 | ml.get_key(LLM_KV_SSM_CONV_KERNEL, hparams.ssm_d_conv); |
10 | 0 | ml.get_key(LLM_KV_KDA_HEAD_DIM, hparams.n_embd_head_kda); |
11 | | |
12 | | // MLA qk_rope_head_dim (for reference) |
13 | | // qk_rope_head_dim = 64, qk_nope_head_dim = 128, qk_head_dim = 192 |
14 | | |
15 | | // Mark KDA layers as recurrent using n_head_kv pattern (like Jamba) |
16 | | // Set n_head_kv = 0 for KDA layers (recurrent), n_head_kv = n_head for MLA layers (attention) |
17 | 0 | for (uint32_t i = 0; i < hparams.n_layer(); ++i) { |
18 | 0 | hparams.is_recr_impl[i] = hparams.n_head_kv(i) == 0; // KDA layers are recurrent |
19 | 0 | } |
20 | | |
21 | | // MoE parameters - Kimi uses moe_intermediate_size = 1024 |
22 | 0 | ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); |
23 | 0 | ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared); |
24 | 0 | ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead, false); |
25 | 0 | ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale, false); |
26 | 0 | ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func); |
27 | |
|
28 | 0 | switch (hparams.n_layer()) { |
29 | 0 | case 27: type = LLM_TYPE_48B_A3B; break; // Kimi-Linear-48B-A3B |
30 | 0 | default: type = LLM_TYPE_UNKNOWN; |
31 | 0 | } |
32 | 0 | } |
33 | | |
34 | 0 | void llama_model_kimi_linear::load_arch_tensors(llama_model_loader &) { |
35 | 0 | LLAMA_LOAD_LOCALS; |
36 | |
|
37 | 0 | tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); |
38 | | |
39 | | // output |
40 | 0 | output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); |
41 | 0 | output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); |
42 | |
|
43 | 0 | for (int i = 0; i < n_layer; ++i) { |
44 | 0 | auto & layer = layers[i]; |
45 | |
|
46 | 0 | layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); |
47 | | |
48 | | // Check for KDA specific tensors to determine layer type or if it's a mixed model |
49 | | // Assuming KDA layer if KDA tensors are present |
50 | | |
51 | | // KDA uses head_dim = 128 (from linear_attn_config.head_dim) |
52 | 0 | const int64_t n_embd_head_k_kda = hparams.n_embd_head_kda; |
53 | 0 | const int64_t n_embd_head_v_kda = hparams.n_embd_head_kda; |
54 | 0 | const int64_t ssm_d_conv = hparams.ssm_d_conv; |
55 | |
|
56 | 0 | if (hparams.is_recr(i)) { |
57 | | // Conv1d weights: try 4D first, then 3D (quantization may remove trailing 1) |
58 | | // 4D: [d_conv, 1, d_inner, 1], 3D: [d_conv, 1, d_inner] |
59 | 0 | layer.ssm_q_conv = create_tensor(tn(LLM_TENSOR_SSM_CONV1D_Q, "weight", i), {ssm_d_conv, 1, n_embd_head_k_kda * n_head, 1}, TENSOR_NOT_REQUIRED); |
60 | 0 | if (!layer.ssm_q_conv) { |
61 | 0 | layer.ssm_q_conv = create_tensor(tn(LLM_TENSOR_SSM_CONV1D_Q, "weight", i), {ssm_d_conv, 1, n_embd_head_k_kda * n_head}, 0); |
62 | 0 | } |
63 | | |
64 | | // KDA Layer - Conv1d weights may be 3D or 4D |
65 | 0 | layer.ssm_k_conv = create_tensor(tn(LLM_TENSOR_SSM_CONV1D_K, "weight", i), {ssm_d_conv, 1, n_embd_head_k_kda * n_head, 1}, TENSOR_NOT_REQUIRED); |
66 | 0 | if (!layer.ssm_k_conv) { |
67 | 0 | layer.ssm_k_conv = create_tensor(tn(LLM_TENSOR_SSM_CONV1D_K, "weight", i), {ssm_d_conv, 1, n_embd_head_k_kda * n_head}, 0); |
68 | 0 | } |
69 | 0 | layer.ssm_v_conv = create_tensor(tn(LLM_TENSOR_SSM_CONV1D_V, "weight", i), {ssm_d_conv, 1, n_embd_head_v_kda * n_head, 1}, TENSOR_NOT_REQUIRED); |
70 | 0 | if (!layer.ssm_v_conv) { |
71 | 0 | layer.ssm_v_conv = create_tensor(tn(LLM_TENSOR_SSM_CONV1D_V, "weight", i), {ssm_d_conv, 1, n_embd_head_v_kda * n_head}, 0); |
72 | 0 | } |
73 | | |
74 | | // q, k, v projections |
75 | | // Python: q_proj, k_proj, v_proj |
76 | 0 | create_tensor_qkv(layer, i, n_embd, n_embd_head_k_kda * n_head, n_embd_head_k_kda * n_head, n_embd_head_v_kda * n_head, 0); |
77 | | |
78 | | // KDA specific projections |
79 | | // f_a_proj, f_b_proj |
80 | 0 | layer.ssm_f_a = create_tensor(tn(LLM_TENSOR_SSM_F_A, "weight", i), {n_embd, n_embd_head_k_kda}, 0); // head_dim |
81 | 0 | layer.ssm_f_b = create_tensor(tn(LLM_TENSOR_SSM_F_B, "weight", i), {n_embd_head_k_kda, n_embd_head_k_kda * n_head}, 0); // projection_size |
82 | | |
83 | | // b_proj (beta mixing coefficient) |
84 | 0 | layer.ssm_beta = create_tensor(tn(LLM_TENSOR_SSM_BETA, "weight", i), {n_embd, n_head}, 0); |
85 | | |
86 | | // A_log - Shape in GGUF: [1, num_heads, 1, 1] (4D) or [1, num_heads] (2D after quantization) Note: -exp(A_log) is applied in convert_hf_to_gguf.py |
87 | 0 | layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A, i), {1, n_head, 1, 1}, TENSOR_NOT_REQUIRED); |
88 | 0 | if (!layer.ssm_a) { |
89 | 0 | layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A, i), {1, n_head}, 0); |
90 | 0 | } |
91 | | |
92 | | // dt_bias - shape [n_embd_head_k_kda * n_head] = [4096] |
93 | 0 | layer.ssm_dt_b = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias", i), {n_embd_head_k_kda * n_head}, 0); |
94 | | |
95 | | // g_a_proj, g_b_proj (output gate) |
96 | 0 | layer.ssm_g_a = create_tensor(tn(LLM_TENSOR_SSM_G_A, "weight", i), {n_embd, n_embd_head_k_kda}, 0); |
97 | 0 | layer.ssm_g_b = create_tensor(tn(LLM_TENSOR_SSM_G_B, "weight", i), {n_embd_head_k_kda, n_embd_head_k_kda * n_head}, 0); |
98 | | |
99 | | // o_norm (reusing SSM_NORM) |
100 | 0 | layer.ssm_o_norm = create_tensor(tn(LLM_TENSOR_SSM_NORM, "weight", i), {n_embd_head_k_kda}, 0); // FusedRMSNormGated |
101 | | |
102 | | // o_proj |
103 | 0 | layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_v_kda * n_head, n_embd}, 0); |
104 | |
|
105 | 0 | } else { |
106 | | // MLA Layer - use MLA-specific head dimensions |
107 | 0 | const int64_t q_lora_rank = hparams.n_lora_q; |
108 | 0 | const int64_t kv_lora_rank = hparams.n_lora_kv; |
109 | 0 | const int64_t n_embd_head_k_mla = hparams.n_embd_head_k_mla(); |
110 | 0 | const int64_t n_embd_head_v_mla = hparams.n_embd_head_v_mla(); |
111 | |
|
112 | 0 | layer.attn_q_a_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_A_NORM, "weight", i), {q_lora_rank}, TENSOR_NOT_REQUIRED); |
113 | 0 | layer.attn_kv_a_norm = create_tensor(tn(LLM_TENSOR_ATTN_KV_A_NORM, "weight", i), {kv_lora_rank}, 0); |
114 | |
|
115 | 0 | if (layer.attn_q_a_norm) { |
116 | 0 | layer.wq_a = create_tensor(tn(LLM_TENSOR_ATTN_Q_A, "weight", i), {n_embd, q_lora_rank}, 0); |
117 | 0 | layer.wq_b = create_tensor(tn(LLM_TENSOR_ATTN_Q_B, "weight", i), {q_lora_rank, n_head * n_embd_head_k_mla}, 0); |
118 | 0 | } else { |
119 | | // Kimi MLA without Q compression: wq = [n_embd, n_head * n_embd_head_k_mla] |
120 | 0 | layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_head * n_embd_head_k_mla}, 0); |
121 | 0 | } |
122 | | |
123 | | // Kimi: qk_rope_head_dim = 64 (actual RoPE dimension for MLA) |
124 | | // Note: hparams.n_rot may be 72 (from conversion) but actual is 64 |
125 | 0 | const int64_t qk_rope_head_dim = hparams.n_rot(); // From config: qk_rope_head_dim |
126 | 0 | layer.wkv_a_mqa = create_tensor(tn(LLM_TENSOR_ATTN_KV_A_MQA, "weight", i), {n_embd, kv_lora_rank + qk_rope_head_dim}, 0); |
127 | | // Support Legacy GGUFs that don't split wkv_b (MLA KV cache disabled) |
128 | 0 | layer.wkv_b = create_tensor(tn(LLM_TENSOR_ATTN_KV_B, "weight", i), |
129 | 0 | {kv_lora_rank, n_head * (n_embd_head_k_mla - qk_rope_head_dim + n_embd_head_v_mla)}, TENSOR_NOT_REQUIRED | TENSOR_SKIP_IF_VIRTUAL); |
130 | 0 | if (!layer.wkv_b) { // MLA KV cache enabled |
131 | 0 | layer.wk_b = create_tensor(tn(LLM_TENSOR_ATTN_K_B, "weight", i), {n_embd_head_k_mla - qk_rope_head_dim, kv_lora_rank, n_head}, 0); |
132 | 0 | layer.wv_b = create_tensor(tn(LLM_TENSOR_ATTN_V_B, "weight", i), {kv_lora_rank, n_embd_head_v_mla, n_head}, 0); |
133 | 0 | } |
134 | 0 | layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_head * n_embd_head_v_mla, n_embd}, 0); |
135 | 0 | } |
136 | |
|
137 | 0 | layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); |
138 | | |
139 | | // MoE intermediate size (different from dense FFN) |
140 | 0 | const int64_t n_ff_exp = hparams.n_ff_exp; |
141 | | |
142 | | // Kimi uses n_layer_dense_lead to determine which layers use dense FFN vs MoE |
143 | | // first_k_dense_replace = 1 means layer 0 uses dense FFN, layers 1+ use MoE |
144 | 0 | if (i < (int) hparams.n_layer_dense_lead) { |
145 | | // Dense FFN layer - use normal n_ff |
146 | 0 | layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); |
147 | 0 | layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); |
148 | 0 | layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); |
149 | 0 | } else { |
150 | | // MoE layer - use n_ff_exp (1024) instead of n_ff (9216) |
151 | 0 | layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); |
152 | 0 | layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, 0); |
153 | 0 | layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0); |
154 | 0 | layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, 0); |
155 | | |
156 | | // Shared experts use moe_intermediate_size * num_shared_experts |
157 | | // Kimi: shared_expert_intermediate_size = 1024 * 1 = 1024 |
158 | | // Tensors are 2D: [n_embd, n_ff_shexp] or [n_ff_shexp, n_embd] |
159 | 0 | const int64_t n_ff_shexp_actual = n_ff_exp * (hparams.n_expert_shared > 0 ? hparams.n_expert_shared : 1); |
160 | 0 | layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, n_ff_shexp_actual}, TENSOR_NOT_REQUIRED); |
161 | 0 | layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {n_ff_shexp_actual, n_embd}, TENSOR_NOT_REQUIRED); |
162 | 0 | layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, n_ff_shexp_actual}, TENSOR_NOT_REQUIRED); |
163 | |
|
164 | 0 | layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, 0); |
165 | 0 | } |
166 | 0 | } |
167 | 0 | } |
168 | | |
169 | 0 | std::unique_ptr<llm_graph_context> llama_model_kimi_linear::build_arch_graph(const llm_graph_params & params) const { |
170 | 0 | return std::make_unique<graph>(*this, params); |
171 | 0 | } |
172 | | |
173 | | // Causal Conv1d function for Q,K,V |
174 | | // When qkv is 0, it is Q, 1 is K, 2 is V |
175 | 0 | static ggml_tensor * causal_conv1d(ggml_cgraph * gf, ggml_context * ctx0, ggml_tensor * conv_states_all, ggml_tensor * conv_state_all, int64_t qkv, ggml_tensor * x, ggml_tensor * proj_w, ggml_tensor * conv_w, int64_t d_conv, int64_t head_dim, int64_t n_head, int64_t n_seq_tokens, int64_t n_seqs, int64_t n_tokens, int64_t kv_head) { |
176 | 0 | const int64_t d_inner = head_dim * n_head; |
177 | 0 | const int64_t conv_state_size = (d_conv - 1) * d_inner; |
178 | 0 | const int64_t n_embd_r_total = 3 * conv_state_size; // Q + K + V |
179 | | |
180 | | // conv_state_all is [n_embd_r_total, n_seqs], split into Q, K, V |
181 | | // Each conv state is [(d_conv-1) * d_inner] per sequence, need to reshape to [d_conv-1, d_inner, n_seqs] |
182 | | // Memory layout: for each seq, Q state is first conv_state_size elements, then K, then V |
183 | | // conv_state_all has stride: nb[0] = element_size, nb[1] = n_embd_r_total * element_size |
184 | | // View Q conv state: offset 0, size conv_state_size per seq |
185 | | // conv_state_all is [n_embd_r_total, n_seqs] with memory layout: |
186 | | // state[i + seq * n_embd_r_total] where i = conv_step + channel * (d_conv-1) + {0, conv_state_size, 2*conv_state_size} for Q/K/V |
187 | | // We want [d_conv-1, d_inner, n_seqs] view: |
188 | | // nb1 = (d_conv-1) * element_size (stride between channels) |
189 | | // nb2 = n_embd_r_total * element_size (stride between seqs) |
190 | 0 | ggml_tensor * conv_state_x = ggml_view_3d(ctx0, conv_state_all, d_conv - 1, d_inner, n_seqs, |
191 | 0 | (d_conv - 1) * ggml_element_size(conv_state_all), // nb1: stride between channels |
192 | 0 | n_embd_r_total * ggml_element_size(conv_state_all), // nb2: stride between seqs |
193 | 0 | qkv * conv_state_size * ggml_element_size(conv_state_all)); |
194 | | |
195 | | // Causal Conv1d function for Q,K,V |
196 | | // When qkv is 0, it is Q, 1 is K, 2 is V |
197 | | // Step 1: Q, K, V projections -> [d_inner, n_tokens] |
198 | 0 | ggml_tensor * x_proj = ggml_mul_mat(ctx0, proj_w, x); |
199 | | |
200 | | // Reshape input: {d_inner, n_tokens} -> {d_inner, n_seq_tokens, n_seqs} |
201 | 0 | ggml_tensor * x_3d = ggml_reshape_3d(ctx0, x_proj, d_inner, n_seq_tokens, n_seqs); |
202 | | |
203 | | // Concat Q conv state and current input: {d_conv-1 + n_seq_tokens, d_inner, n_seqs} |
204 | 0 | ggml_tensor * conv_x = ggml_concat(ctx0, conv_state_x, ggml_transpose(ctx0, x_3d), 0); |
205 | | |
206 | | // Save last (d_conv-1) columns back to Q conv state |
207 | 0 | ggml_tensor * last_conv_x = ggml_view_3d(ctx0, conv_x, d_conv - 1, d_inner, n_seqs, |
208 | 0 | conv_x->nb[1], conv_x->nb[2], n_seq_tokens * conv_x->nb[0]); |
209 | 0 | ggml_build_forward_expand(gf, |
210 | 0 | ggml_cpy(ctx0, last_conv_x, |
211 | 0 | ggml_view_3d(ctx0, conv_states_all, |
212 | 0 | d_conv - 1, d_inner, n_seqs, |
213 | 0 | (d_conv - 1) * ggml_element_size(conv_states_all), // nb1: contiguous within one channel's conv taps |
214 | 0 | n_embd_r_total * ggml_element_size(conv_states_all), // nb2: stride between sequences (skip over K,V states) |
215 | 0 | (kv_head * n_embd_r_total + qkv * conv_state_size) * ggml_element_size(conv_states_all)))); // offset to first seq's Q/K/V state |
216 | | // Reshape conv weight: GGUF [d_conv, 1, d_inner, 1] -> ggml_ssm_conv expects [d_conv, d_inner] |
217 | | // GGUF stores as [d_conv, 1, d_inner, 1] with memory layout w[conv_step + channel * d_conv] |
218 | | // vLLM stores as [d_inner, d_conv] with memory layout w[channel * d_conv + conv_step] |
219 | | // ggml_ssm_conv computes: c[conv_step + channel * d_conv] |
220 | | // GGUF layout: [d_conv, 1, d_inner] or [d_conv, 1, d_inner, 1] -> reshape to [d_conv, d_inner] |
221 | | // Reshape conv weight from [d_conv, 1, d_inner, 1] to [d_conv, d_inner] for ggml_ssm_conv |
222 | 0 | ggml_tensor * conv_weight = ggml_reshape_2d(ctx0, conv_w, d_conv, d_inner); |
223 | | |
224 | | // Apply conv1d |
225 | | // ggml_ssm_conv output: {d_inner, n_seq_tokens, n_seqs} |
226 | 0 | ggml_tensor * Xcur = ggml_ssm_conv(ctx0, conv_x, conv_weight); |
227 | | // Reshape to 2D for bias add: {d_inner, n_tokens} |
228 | 0 | Xcur = ggml_reshape_2d(ctx0, Xcur, d_inner, n_tokens); |
229 | 0 | Xcur = ggml_silu(ctx0, Xcur); |
230 | |
|
231 | 0 | return ggml_reshape_4d(ctx0, Xcur, head_dim, n_head, n_seq_tokens, n_seqs); |
232 | 0 | } |
233 | | |
234 | | llama_model_kimi_linear::graph::graph(const llama_model & model, const llm_graph_params & params) : |
235 | 0 | llm_build_delta_net_base(params), model(model) { |
236 | 0 | ggml_tensor * cur; |
237 | 0 | ggml_tensor * inpL; |
238 | |
|
239 | 0 | inpL = build_inp_embd(model.tok_embd); |
240 | 0 | cb(inpL, "model.embed_tokens", -1); |
241 | | |
242 | | // Note: Kimi MLA does NOT use RoPE (rotary_emb=None in vLLM) |
243 | | // So we don't need inp_pos |
244 | |
|
245 | 0 | auto * inp_kv = !hparams.is_mla() ? build_inp_mem_hybrid() : nullptr; |
246 | 0 | auto * inp_k = hparams.is_mla() ? build_inp_mem_hybrid_k() : nullptr; |
247 | 0 | auto * inp_rs = hparams.is_mla() ? inp_k->get_recr() : inp_kv->get_recr(); |
248 | 0 | auto * inp_attn_kv = !hparams.is_mla() ? inp_kv->get_attn() : nullptr; |
249 | 0 | auto * inp_attn_k = hparams.is_mla() ? inp_k->get_attn() : nullptr; |
250 | | |
251 | | // Output ids for selecting which tokens to output |
252 | 0 | ggml_tensor * inp_out_ids = build_inp_out_ids(); |
253 | | |
254 | | // Kimi dimension constants |
255 | 0 | const int64_t n_head = hparams.n_head(); |
256 | 0 | const int64_t head_dim = hparams.n_embd_head_kda; |
257 | 0 | const int64_t d_conv = hparams.ssm_d_conv; |
258 | 0 | const int64_t d_inner = n_head * head_dim; // 32 * 128 = 4096 |
259 | 0 | const int64_t n_seqs = ubatch.n_seqs; |
260 | 0 | const int64_t n_seq_tokens = ubatch.n_seq_tokens; |
261 | | |
262 | | // Verify batch consistency for recurrent layers |
263 | 0 | GGML_ASSERT(n_seqs != 0); |
264 | 0 | GGML_ASSERT(ubatch.equal_seqs()); |
265 | 0 | GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs); |
266 | | |
267 | | // MLA params |
268 | 0 | const int64_t n_embd_head_k_mla = hparams.n_embd_head_k_mla(); |
269 | 0 | const int64_t n_embd_head_v_mla = hparams.n_embd_head_v_mla(); |
270 | 0 | const int64_t kv_lora_rank = hparams.n_lora_kv; |
271 | | // qk_rope_head_dim = 64 (from Kimi config) which is hparams.n_rot |
272 | | // Confirmed from tensor shape: wkv_a_mqa [2304, 576] = [n_embd, kv_lora_rank + qk_rope_head_dim] |
273 | 0 | const int64_t n_embd_head_qk_rope = hparams.n_rot(); // config.qk_rope_head_dim |
274 | 0 | const int64_t n_embd_head_qk_nope = n_embd_head_k_mla - n_embd_head_qk_rope; // 192 - 64 = 128 |
275 | | // Attention scale for MLA |
276 | 0 | const float kq_scale_mla = 1.0f / sqrtf((float)n_embd_head_k_mla); |
277 | |
|
278 | 0 | for (int il = 0; il < n_layer; ++il) { |
279 | 0 | const auto & layer = model.layers[il]; |
280 | 0 | ggml_tensor * inpSA = inpL; |
281 | | |
282 | | // Attention Norm |
283 | 0 | cur = build_norm(inpL, layer.attn_norm, NULL, LLM_NORM_RMS, il); |
284 | 0 | cb(cur, "attn_norm", il); |
285 | |
|
286 | 0 | ggml_build_forward_expand(gf, cur); |
287 | |
|
288 | 0 | if (hparams.is_recr(il)) { |
289 | | // === KDA Layer (Kimi Delta Attention) with Recurrent State === |
290 | | // Reference: vLLM kda.py |
291 | 0 | const auto * mctx_cur = inp_rs->mctx; |
292 | 0 | const auto kv_head = mctx_cur->get_head(); |
293 | | |
294 | | // Get conv states from r_l tensor (Q, K, V each have separate state) |
295 | 0 | ggml_tensor * conv_states_all = mctx_cur->get_r_l(il); |
296 | 0 | cb(conv_states_all, "conv_states_all", il); |
297 | 0 | ggml_tensor * conv_state_all = build_rs(inp_rs, conv_states_all, hparams.n_embd_r(), n_seqs); |
298 | 0 | ggml_tensor * Qcur = causal_conv1d(gf, ctx0, conv_states_all, conv_state_all, 0, cur, layer.wq, layer.ssm_q_conv, d_conv, head_dim, n_head, n_seq_tokens, n_seqs, n_tokens, kv_head); |
299 | 0 | ggml_tensor * Kcur = causal_conv1d(gf, ctx0, conv_states_all, conv_state_all, 1, cur, layer.wk, layer.ssm_k_conv, d_conv, head_dim, n_head, n_seq_tokens, n_seqs, n_tokens, kv_head); |
300 | 0 | ggml_tensor * Vcur = causal_conv1d(gf, ctx0, conv_states_all, conv_state_all, 2, cur, layer.wv, layer.ssm_v_conv, d_conv, head_dim, n_head, n_seq_tokens, n_seqs, n_tokens, kv_head); |
301 | | |
302 | | // g1 = -exp(A_log) * softplus(f_b(f_a(x)) + dt_bias) |
303 | 0 | ggml_tensor * f_a = ggml_mul_mat(ctx0, layer.ssm_f_a, cur); |
304 | 0 | ggml_tensor * g1 = ggml_mul_mat(ctx0, layer.ssm_f_b, f_a); |
305 | 0 | cb(g1, "g1 f_b(f_a(cur))", il); |
306 | 0 | g1 = ggml_add(ctx0, g1, layer.ssm_dt_b); |
307 | 0 | g1 = ggml_softplus(ctx0, g1); |
308 | 0 | g1 = ggml_reshape_3d(ctx0, g1, head_dim, n_head, n_tokens); |
309 | | |
310 | | // A_log shape is [1, n_head] or [1, n_head, 1, 1], need to broadcast to [head_dim, n_head, n_tokens]. No need to -exp(a_log) because it was done in convert_hf_to_gguf.py |
311 | | // Reshape to [1, n_head, 1] for broadcasting with g1 [head_dim, n_head, n_tokens] |
312 | 0 | ggml_tensor * A = ggml_reshape_3d(ctx0, layer.ssm_a, 1, n_head, 1); |
313 | 0 | g1 = ggml_mul(ctx0, g1, A); |
314 | 0 | cb(g1, "kda_g1", il); |
315 | |
|
316 | 0 | g1 = ggml_reshape_4d(ctx0, g1, head_dim, n_head, n_seq_tokens, n_seqs); |
317 | | |
318 | | // Compute beta (mixing coefficient) |
319 | 0 | ggml_tensor * beta = ggml_mul_mat(ctx0, layer.ssm_beta, cur); |
320 | 0 | beta = ggml_reshape_4d(ctx0, beta, 1, n_head, n_seq_tokens, n_seqs); |
321 | 0 | cb(beta, "kda_beta", il); |
322 | |
|
323 | 0 | beta = ggml_sigmoid(ctx0, beta); |
324 | | |
325 | | // Reshape for KDA recurrence |
326 | | // {n_embd, n_tokens} -> {n_embd, n_seq_tokens, n_seqs} |
327 | 0 | cur = ggml_reshape_3d(ctx0, cur, cur->ne[0], n_seq_tokens, n_seqs); |
328 | | |
329 | | // Get SSM state and compute KDA recurrence using ggml_kda_scan |
330 | 0 | ggml_tensor * ssm_states_all = mctx_cur->get_s_l(il); |
331 | 0 | ggml_tensor * state = build_rs(inp_rs, ssm_states_all, hparams.n_embd_s(), n_seqs); |
332 | 0 | state = ggml_reshape_4d(ctx0, state, head_dim, head_dim, n_head, n_seqs); |
333 | |
|
334 | 0 | const float eps_norm = hparams.f_norm_rms_eps; |
335 | |
|
336 | 0 | Qcur = ggml_l2_norm(ctx0, Qcur, eps_norm); |
337 | 0 | Kcur = ggml_l2_norm(ctx0, Kcur, eps_norm); |
338 | | |
339 | | // Choose between build_delta_net_chunking and build_delta_net_recurrent based on n_tokens |
340 | 0 | auto attn_out = build_delta_net(Qcur, Kcur, Vcur, g1, beta, state, il); |
341 | |
|
342 | 0 | ggml_tensor * output = ggml_cont(ctx0, attn_out.first); |
343 | 0 | ggml_tensor * new_state = attn_out.second; |
344 | 0 | cb(output, "attn_output", il); |
345 | 0 | cb(new_state, "new_state", il); |
346 | | |
347 | | // Update the recurrent states |
348 | 0 | ggml_build_forward_expand(gf, |
349 | 0 | ggml_cpy(ctx0, new_state, |
350 | 0 | ggml_view_1d(ctx0, ssm_states_all, hparams.n_embd_s() * n_seqs, |
351 | 0 | kv_head * hparams.n_embd_s() * ggml_element_size(ssm_states_all)))); |
352 | | |
353 | | // Output gating g2 = g_b(g_a(x)) |
354 | 0 | ggml_tensor * cur_2d = ggml_reshape_2d(ctx0, cur, cur->ne[0], n_seq_tokens * n_seqs); |
355 | 0 | ggml_tensor * g_a = ggml_mul_mat(ctx0, layer.ssm_g_a, cur_2d); |
356 | 0 | ggml_tensor * g2 = ggml_mul_mat(ctx0, layer.ssm_g_b, g_a); |
357 | 0 | cb(g2, "g2 g_b(g_a(cur_2d))", il); |
358 | 0 | g2 = ggml_reshape_3d(ctx0, g2, head_dim, n_head, n_seq_tokens * n_seqs); |
359 | | |
360 | | // Apply o_norm with sigmoid gating |
361 | | // Note: Kimi model uses sigmoid gating, not SiLU (despite FusedRMSNormGated default being swish) |
362 | | // Formula: output = RMSNorm(x) * sigmoid(g) |
363 | 0 | ggml_tensor * attn_out_final = ggml_reshape_3d(ctx0, output, head_dim, n_head, n_seq_tokens * n_seqs); |
364 | 0 | ggml_tensor * normed = build_norm(attn_out_final, layer.ssm_o_norm, nullptr, LLM_NORM_RMS, il); |
365 | 0 | cb(normed, "kda_normed", il); |
366 | 0 | ggml_tensor * gate = ggml_sigmoid(ctx0, g2); |
367 | 0 | ggml_tensor * gated = ggml_mul(ctx0, normed, gate); |
368 | | |
369 | | // Output projection |
370 | 0 | gated = ggml_cont_2d(ctx0, gated, d_inner, n_tokens); |
371 | 0 | cur = ggml_mul_mat(ctx0, layer.wo, gated); |
372 | 0 | cb(cur, "kda_out", il); |
373 | |
|
374 | 0 | } else { |
375 | | // === MLA Layer (Multi-head Latent Attention) without KV Cache === |
376 | | // Reference: vLLM mla.py |
377 | | // Step 1: Q projection and reshape |
378 | | // vLLM Kimi: q = q_proj(hidden_states), then view as [n_tokens, n_head, qk_head_dim] |
379 | | // Note: Kimi MLA does NOT use RoPE (rotary_emb=None in vLLM) |
380 | 0 | ggml_tensor * Qcur = ggml_mul_mat(ctx0, layer.wq, cur); |
381 | | |
382 | | // Step 2: KV compression |
383 | | // kv_cmpr_pe = kv_a_proj_with_mqa(hidden_states) -> [kv_lora_rank + qk_rope_head_dim, n_tokens] |
384 | 0 | ggml_tensor * kv_cmpr_pe = ggml_mul_mat(ctx0, layer.wkv_a_mqa, cur); |
385 | | |
386 | | // Split: kv_cmpr = kv_lora[:kv_lora_rank], k_pe = kv_lora[kv_lora_rank:] |
387 | 0 | ggml_tensor * kv_cmpr = ggml_view_2d(ctx0, kv_cmpr_pe, kv_lora_rank, n_tokens, |
388 | 0 | ggml_row_size(kv_cmpr_pe->type, kv_lora_rank + n_embd_head_qk_rope), 0); |
389 | 0 | ggml_tensor * k_pe = ggml_view_3d(ctx0, kv_cmpr_pe, n_embd_head_qk_rope, 1, n_tokens, |
390 | 0 | ggml_row_size(kv_cmpr_pe->type, kv_lora_rank + n_embd_head_qk_rope), |
391 | 0 | ggml_row_size(kv_cmpr_pe->type, kv_lora_rank + n_embd_head_qk_rope), |
392 | 0 | ggml_row_size(kv_cmpr_pe->type, kv_lora_rank)); |
393 | | // Note: Kimi MLA does NOT apply RoPE (rotary_emb=None in vLLM) |
394 | | // k_pe is used directly without RoPE |
395 | | // Normalize kv_c |
396 | 0 | kv_cmpr = build_norm(kv_cmpr, layer.attn_kv_a_norm, nullptr, LLM_NORM_RMS, il); |
397 | |
|
398 | 0 | if (layer.wk_b && layer.wv_b) { // MLA KV cache enabled |
399 | | // extract q_nope |
400 | 0 | ggml_tensor * q_nope = |
401 | 0 | ggml_view_3d(ctx0, Qcur, n_embd_head_qk_nope, n_head, n_tokens, ggml_row_size(Qcur->type, n_embd_head_k_mla), |
402 | 0 | ggml_row_size(Qcur->type, n_embd_head_k_mla) * n_head, 0); |
403 | 0 | cb(q_nope, "q_nope", il); |
404 | | |
405 | | // and {n_embd_head_qk_rope, n_head, n_tokens} |
406 | 0 | ggml_tensor * q_pe = ggml_view_3d( |
407 | 0 | ctx0, Qcur, n_embd_head_qk_rope, n_head, n_tokens, ggml_row_size(Qcur->type, n_embd_head_k_mla), |
408 | 0 | ggml_row_size(Qcur->type, n_embd_head_k_mla) * n_head, ggml_row_size(Qcur->type, n_embd_head_qk_nope)); |
409 | 0 | cb(q_pe, "q_pe", il); |
410 | | |
411 | | // {n_embd_head_qk_nope, n_tokens, n_head} |
412 | 0 | q_nope = ggml_permute(ctx0, q_nope, 0, 2, 1, 3); |
413 | 0 | cb(q_nope, "q_nope_perm", il); |
414 | | |
415 | | // {n_embd_head_qk_nope, kv_lora_rank, n_head} x {n_embd_head_qk_nope, n_tokens, n_head} |
416 | 0 | ggml_tensor * q_nope_absorbed = ggml_mul_mat(ctx0, layer.wk_b, q_nope); |
417 | 0 | cb(q_nope_absorbed, "q_nope_absorbed", il); |
418 | | |
419 | | // {kv_lora_rank, n_head, n_tokens} |
420 | 0 | q_nope_absorbed = ggml_permute(ctx0, q_nope_absorbed, 0, 2, 1, 3); |
421 | 0 | cb(q_nope_absorbed, "q_nope_absorbed_perm", il); |
422 | | |
423 | | // {n_embd_head_qk_rope + kv_lora_rank, n_head, n_tokens} |
424 | | // note: rope must go first for in-place context shifting in build_rope_shift() |
425 | 0 | Qcur = ggml_concat(ctx0, q_nope_absorbed, q_pe, 0); |
426 | 0 | cb(Qcur, "Qcur", il); |
427 | |
|
428 | 0 | kv_cmpr = ggml_reshape_3d(ctx0, kv_cmpr, kv_lora_rank, 1, n_tokens); |
429 | 0 | cb(kv_cmpr, "kv_cmpr_reshape", il); |
430 | | |
431 | | // {n_embd_head_qk_rope + kv_lora_rank, 1, n_tokens} |
432 | 0 | ggml_tensor * Kcur = ggml_concat(ctx0, kv_cmpr, k_pe, 0); |
433 | 0 | cb(Kcur, "Kcur", il); |
434 | | |
435 | | // {kv_lora_rank, 1, n_tokens} |
436 | 0 | ggml_tensor * Vcur = kv_cmpr; |
437 | 0 | cb(Vcur, "Vcur", il); |
438 | |
|
439 | 0 | cur = build_attn(inp_attn_k, layer.wo, NULL, layer.wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, layer.wv_b, kq_scale_mla, il); |
440 | 0 | cb(cur, "mla_out", il); |
441 | 0 | } else { // MLA KV cache disabled. Fall back to MHA KV cache. |
442 | 0 | Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head_k_mla, n_head, n_tokens); |
443 | 0 | cb(Qcur, "mla_Q", il); |
444 | | // KV decompression: kv = kv_b_proj(kv_c_normed) |
445 | 0 | ggml_tensor * kv = ggml_mul_mat(ctx0, layer.wkv_b, kv_cmpr); |
446 | 0 | const int64_t kv_per_head = n_embd_head_qk_nope + n_embd_head_v_mla; |
447 | | |
448 | | // Split kv into k_nope and v |
449 | 0 | ggml_tensor * k_nope = ggml_view_3d(ctx0, kv, n_embd_head_qk_nope, n_head, n_tokens, |
450 | 0 | ggml_row_size(kv->type, kv_per_head), |
451 | 0 | ggml_row_size(kv->type, kv_per_head * n_head), 0); |
452 | 0 | ggml_tensor * Vcur = ggml_view_3d(ctx0, kv, n_embd_head_v_mla, n_head, n_tokens, |
453 | 0 | ggml_row_size(kv->type, kv_per_head), |
454 | 0 | ggml_row_size(kv->type, kv_per_head * n_head), |
455 | 0 | ggml_row_size(kv->type, n_embd_head_qk_nope)); |
456 | 0 | Vcur = ggml_cont(ctx0, Vcur); |
457 | 0 | cb(Vcur, "mla_V", il); |
458 | | |
459 | | // Concatenate k_nope + k_pe (broadcast k_pe to all heads) |
460 | | // K = [k_nope, k_pe] where k_nope is [qk_nope_head_dim, n_head, n_tokens] |
461 | | // and k_pe is [qk_rope_head_dim, 1, n_tokens] broadcast to all heads |
462 | | // Need to broadcast k_pe from [qk_rope, 1, n_tokens] to [qk_rope, n_head, n_tokens] |
463 | 0 | ggml_tensor * k_pe_target = ggml_new_tensor_3d(ctx0, k_pe->type, n_embd_head_qk_rope, n_head, n_tokens); |
464 | 0 | ggml_tensor * k_pe_repeated = ggml_repeat(ctx0, k_pe, k_pe_target); |
465 | 0 | ggml_tensor * Kcur = ggml_concat(ctx0, k_pe_repeated, k_nope, 0); |
466 | 0 | cb(Kcur, "mla_K", il); |
467 | | |
468 | | // Direct softmax attention (with MHA KV cache) |
469 | | // Use build_attn with inp_attn for proper mask handling |
470 | 0 | cur = build_attn(inp_attn_kv, layer.wo, NULL, layer.wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale_mla, il); |
471 | 0 | cb(cur, "mla_out", il); |
472 | 0 | } |
473 | 0 | } |
474 | | |
475 | | // On last layer, select only the output tokens |
476 | 0 | if (il == n_layer - 1 && inp_out_ids) { |
477 | 0 | cur = ggml_get_rows(ctx0, cur, inp_out_ids); |
478 | 0 | inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); |
479 | 0 | } |
480 | | |
481 | | // Residual |
482 | 0 | ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); |
483 | 0 | cb(ffn_inp, "ffn_inp", il); |
484 | | |
485 | | // FFN Norm |
486 | 0 | cur = build_norm(ffn_inp, layer.ffn_norm, NULL, LLM_NORM_RMS, il); |
487 | 0 | cb(cur, "ffn_norm", il); |
488 | |
|
489 | 0 | if ((uint32_t) il < hparams.n_layer_dense_lead) { |
490 | | // Dense FFN layer |
491 | 0 | cur = build_ffn(cur, |
492 | 0 | layer.ffn_up, NULL, NULL, |
493 | 0 | layer.ffn_gate, NULL, NULL, |
494 | 0 | layer.ffn_down, NULL, NULL, |
495 | 0 | NULL, LLM_FFN_SILU, LLM_FFN_PAR, il); |
496 | 0 | cb(cur, "ffn_out", il); |
497 | 0 | } else { |
498 | | // MoE layer |
499 | | // Kimi uses moe_renormalize=True and routed_scaling_factor (stored as expert_weights_scale) = 2.446 |
500 | 0 | ggml_tensor * moe_out = build_moe_ffn(cur, |
501 | 0 | layer.ffn_gate_inp, |
502 | 0 | layer.ffn_up_exps, |
503 | 0 | layer.ffn_gate_exps, |
504 | 0 | layer.ffn_down_exps, |
505 | 0 | layer.ffn_exp_probs_b, |
506 | 0 | hparams.n_expert, |
507 | 0 | hparams.n_expert_used, |
508 | 0 | LLM_FFN_SILU, true, |
509 | 0 | hparams.expert_weights_scale, |
510 | 0 | (llama_expert_gating_func_type) hparams.expert_gating_func, |
511 | 0 | il); |
512 | 0 | cb(moe_out, "ffn_moe_out", il); |
513 | | |
514 | | // Shared expert |
515 | 0 | { |
516 | 0 | ggml_tensor * ffn_shexp = build_ffn(cur, |
517 | 0 | layer.ffn_up_shexp, NULL, NULL, |
518 | 0 | layer.ffn_gate_shexp, NULL, NULL, |
519 | 0 | layer.ffn_down_shexp, NULL, NULL, |
520 | 0 | NULL, LLM_FFN_SILU, LLM_FFN_PAR, il); |
521 | 0 | cb(ffn_shexp, "ffn_shexp", il); |
522 | |
|
523 | 0 | cur = ggml_add(ctx0, moe_out, ffn_shexp); |
524 | 0 | cb(cur, "ffn_out", il); |
525 | 0 | } |
526 | 0 | } |
527 | | // Residual |
528 | 0 | cur = ggml_add(ctx0, cur, ffn_inp); |
529 | |
|
530 | 0 | cur = build_cvec(cur, il); |
531 | 0 | cb(cur, "l_out", il); |
532 | | |
533 | | // input for next layer |
534 | 0 | inpL = cur; |
535 | 0 | } |
536 | 0 | cur = inpL; |
537 | | |
538 | | // Final Norm |
539 | 0 | cur = build_norm(cur, model.output_norm, NULL, LLM_NORM_RMS, -1); |
540 | |
|
541 | 0 | cb(cur, "result_norm", -1); |
542 | 0 | res->t_embd = cur; |
543 | | |
544 | | // Output |
545 | 0 | cur = ggml_mul_mat(ctx0, model.output, cur); |
546 | 0 | cb(cur, "result_output", -1); |
547 | 0 | res->t_logits = cur; |
548 | |
|
549 | 0 | ggml_build_forward_expand(gf, cur); |
550 | 0 | } |