Coverage Report

Created: 2026-03-07 06:35

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
/src/llama.cpp/src/models/kimi-linear.cpp
Line
Count
Source
1
#include "models.h"
2
#include "ggml.h"
3
4
#include "llama-memory-recurrent.h"
5
6
// Causal Conv1d function for Q,K,V
7
// When qkv is 0, it is Q, 1 is K, 2 is V
8
0
static ggml_tensor * causal_conv1d(ggml_cgraph * gf, ggml_context * ctx0, ggml_tensor * conv_states_all, ggml_tensor * conv_state_all, int64_t qkv, ggml_tensor * x, ggml_tensor * proj_w, ggml_tensor * conv_w, int64_t d_conv, int64_t head_dim, int64_t n_head, int64_t n_seq_tokens, int64_t n_seqs, int64_t n_tokens, int64_t kv_head) {
9
0
    const int64_t d_inner = head_dim * n_head;
10
0
    const int64_t conv_state_size = (d_conv - 1) * d_inner;
11
0
    const int64_t n_embd_r_total = 3 * conv_state_size;  // Q + K + V
12
13
    // conv_state_all is [n_embd_r_total, n_seqs], split into Q, K, V
14
    // Each conv state is [(d_conv-1) * d_inner] per sequence, need to reshape to [d_conv-1, d_inner, n_seqs]
15
    // Memory layout: for each seq, Q state is first conv_state_size elements, then K, then V
16
    // conv_state_all has stride: nb[0] = element_size, nb[1] = n_embd_r_total * element_size
17
    // View Q conv state: offset 0, size conv_state_size per seq
18
    // conv_state_all is [n_embd_r_total, n_seqs] with memory layout:
19
    //   state[i + seq * n_embd_r_total] where i = conv_step + channel * (d_conv-1) + {0, conv_state_size, 2*conv_state_size} for Q/K/V
20
    // We want [d_conv-1, d_inner, n_seqs] view:
21
    //   nb1 = (d_conv-1) * element_size (stride between channels)
22
    //   nb2 = n_embd_r_total * element_size (stride between seqs)
23
0
    ggml_tensor * conv_state_x = ggml_view_3d(ctx0, conv_state_all, d_conv - 1, d_inner, n_seqs,
24
0
        (d_conv - 1) * ggml_element_size(conv_state_all),  // nb1: stride between channels
25
0
        n_embd_r_total * ggml_element_size(conv_state_all),  // nb2: stride between seqs
26
0
        qkv * conv_state_size * ggml_element_size(conv_state_all));
27
28
// Causal Conv1d function for Q,K,V
29
// When qkv is 0, it is Q, 1 is K, 2 is V
30
    // Step 1: Q, K, V projections -> [d_inner, n_tokens]
31
0
    ggml_tensor * x_proj = ggml_mul_mat(ctx0, proj_w, x);
32
33
    // Reshape input: {d_inner, n_tokens} -> {d_inner, n_seq_tokens, n_seqs}
34
0
    ggml_tensor * x_3d = ggml_reshape_3d(ctx0, x_proj, d_inner, n_seq_tokens, n_seqs);
35
36
    // Concat Q conv state and current input: {d_conv-1 + n_seq_tokens, d_inner, n_seqs}
37
0
    ggml_tensor * conv_x = ggml_concat(ctx0, conv_state_x, ggml_transpose(ctx0, x_3d), 0);
38
39
    // Save last (d_conv-1) columns back to Q conv state
40
0
    ggml_tensor * last_conv_x = ggml_view_3d(ctx0, conv_x, d_conv - 1, d_inner, n_seqs,
41
0
        conv_x->nb[1], conv_x->nb[2], n_seq_tokens * conv_x->nb[0]);
42
0
    ggml_build_forward_expand(gf,
43
0
        ggml_cpy(ctx0, last_conv_x,
44
0
            ggml_view_3d(ctx0, conv_states_all,
45
0
                d_conv - 1, d_inner, n_seqs,
46
0
                (d_conv - 1) * ggml_element_size(conv_states_all),           // nb1: contiguous within one channel's conv taps
47
0
                n_embd_r_total * ggml_element_size(conv_states_all),         // nb2: stride between sequences (skip over K,V states)
48
0
                (kv_head * n_embd_r_total + qkv * conv_state_size) * ggml_element_size(conv_states_all))));  // offset to first seq's Q/K/V state
49
    // Reshape conv weight: GGUF [d_conv, 1, d_inner, 1] -> ggml_ssm_conv expects [d_conv, d_inner]
50
    // GGUF stores as [d_conv, 1, d_inner, 1] with memory layout w[conv_step + channel * d_conv]
51
    // vLLM stores as [d_inner, d_conv] with memory layout w[channel * d_conv + conv_step]
52
    // ggml_ssm_conv computes: c[conv_step + channel * d_conv]
53
    // GGUF layout: [d_conv, 1, d_inner] or [d_conv, 1, d_inner, 1] -> reshape to [d_conv, d_inner]
54
    // Reshape conv weight from [d_conv, 1, d_inner, 1] to [d_conv, d_inner] for ggml_ssm_conv
55
0
    ggml_tensor * conv_weight = ggml_reshape_2d(ctx0, conv_w, d_conv, d_inner);
56
57
    // Apply conv1d
58
    // ggml_ssm_conv output: {d_inner, n_seq_tokens, n_seqs}
59
0
    ggml_tensor * Xcur = ggml_ssm_conv(ctx0, conv_x, conv_weight);
60
    // Reshape to 2D for bias add: {d_inner, n_tokens}
61
0
    Xcur = ggml_reshape_2d(ctx0, Xcur, d_inner, n_tokens);
62
0
    Xcur = ggml_silu(ctx0, Xcur);
63
64
0
    return ggml_reshape_4d(ctx0, Xcur, head_dim, n_head, n_seq_tokens, n_seqs);
65
0
}
66
67
// Builds the compute graph for the Kimi Linear architecture.
//
// Every layer is: RMS norm -> attention block -> residual -> RMS norm -> FFN -> residual.
// The attention block is selected per layer from the tensors that are present:
//   - KDA (Kimi Delta Attention): layer.ssm_a is set; uses recurrent conv + delta-net state.
//   - MLA (Multi-head Latent Attention): layer.wkv_a_mqa is set; uses the attention cache.
// The FFN is dense for the first hparams.n_layer_dense_lead layers and MoE + shared expert after.
//
// Results are published through res->t_embd (final norm output) and res->t_logits.
llm_build_kimi_linear::llm_build_kimi_linear(const llama_model & model, const llm_graph_params & params) :
    llm_build_delta_net_base(params), model(model) {
    ggml_tensor * cur;
    ggml_tensor * inpL;

    inpL = build_inp_embd(model.tok_embd);
    cb(inpL, "model.embed_tokens", -1);

    // Note: Kimi MLA does NOT use RoPE (rotary_emb=None in vLLM)
    // So we don't need inp_pos

    // Exactly one of the two hybrid inputs is created, depending on hparams.is_mla();
    // the recurrent part is taken from whichever one exists.
    auto * inp_kv = !hparams.is_mla() ? build_inp_mem_hybrid() : nullptr;
    auto * inp_k = hparams.is_mla() ? build_inp_mem_hybrid_k() : nullptr;
    auto * inp_rs = hparams.is_mla() ? inp_k->get_recr() : inp_kv->get_recr();
    auto * inp_attn_kv = !hparams.is_mla() ? inp_kv->get_attn() : nullptr;
    auto * inp_attn_k = hparams.is_mla() ? inp_k->get_attn() : nullptr;

    // Output ids for selecting which tokens to output
    ggml_tensor * inp_out_ids = build_inp_out_ids();

    // Kimi dimension constants
    const int64_t n_head = hparams.n_head();
    const int64_t head_dim = hparams.n_embd_head_kda;
    const int64_t d_conv = hparams.ssm_d_conv;
    const int64_t d_inner = n_head * head_dim;  // 32 * 128 = 4096
    const int64_t n_seqs = ubatch.n_seqs;
    const int64_t n_seq_tokens = ubatch.n_seq_tokens;

    // Verify batch consistency for recurrent layers
    GGML_ASSERT(n_seqs != 0);
    GGML_ASSERT(ubatch.equal_seqs());
    GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs);

    // MLA params
    const int64_t n_embd_head_k_mla = hparams.n_embd_head_k_mla();
    const int64_t n_embd_head_v_mla = hparams.n_embd_head_v_mla();
    const int64_t kv_lora_rank = hparams.n_lora_kv;
    // qk_rope_head_dim = 64 (from Kimi config) which is hparams.n_rot
    // Confirmed from tensor shape: wkv_a_mqa [2304, 576] = [n_embd, kv_lora_rank + qk_rope_head_dim]
    const int64_t n_embd_head_qk_rope = hparams.n_rot;  // config.qk_rope_head_dim
    const int64_t n_embd_head_qk_nope = n_embd_head_k_mla - n_embd_head_qk_rope;  // 192 - 64 = 128
    // Attention scale for MLA
    const float kq_scale_mla = 1.0f / sqrtf((float)n_embd_head_k_mla);

    for (int il = 0; il < n_layer; ++il) {
        const auto & layer = model.layers[il];
        ggml_tensor * inpSA = inpL;

        // Attention Norm
        cur = build_norm(inpL, layer.attn_norm, NULL, LLM_NORM_RMS, il);
        cb(cur, "attn_norm", il);

        ggml_build_forward_expand(gf, cur);

        // Check layer type by checking which tensors exist
        // KDA layers have the ssm_a (A_log) tensor, MLA layers have the wkv_a_mqa tensor
        bool is_kda = (layer.ssm_a != nullptr);
        bool is_mla = (layer.wkv_a_mqa != nullptr);

        if (is_kda) {
            // === KDA Layer (Kimi Delta Attention) with Recurrent State ===
            // Reference: vLLM kda.py
            const auto * mctx_cur = inp_rs->mctx;
            const auto kv_head = mctx_cur->get_head();

            // Get conv states from r_l tensor (Q, K, V each have separate state)
            ggml_tensor * conv_states_all = mctx_cur->get_r_l(il);
            cb(conv_states_all, "conv_states_all", il);
            ggml_tensor * conv_state_all = build_rs(inp_rs, conv_states_all, hparams.n_embd_r(), n_seqs);
            ggml_tensor * Qcur = causal_conv1d(gf, ctx0, conv_states_all, conv_state_all, 0, cur, layer.wq, layer.ssm_q_conv, d_conv, head_dim, n_head, n_seq_tokens, n_seqs, n_tokens, kv_head);
            ggml_tensor * Kcur = causal_conv1d(gf, ctx0, conv_states_all, conv_state_all, 1, cur, layer.wk, layer.ssm_k_conv, d_conv, head_dim, n_head, n_seq_tokens, n_seqs, n_tokens, kv_head);
            ggml_tensor * Vcur = causal_conv1d(gf, ctx0, conv_states_all, conv_state_all, 2, cur, layer.wv, layer.ssm_v_conv, d_conv, head_dim, n_head, n_seq_tokens, n_seqs, n_tokens, kv_head);

            // g1 = -exp(A_log) * softplus(f_b(f_a(x)) + dt_bias)
            ggml_tensor * f_a = ggml_mul_mat(ctx0, layer.ssm_f_a, cur);
            ggml_tensor * g1 = ggml_mul_mat(ctx0, layer.ssm_f_b, f_a);
            cb(g1, "g1 f_b(f_a(cur))", il);
            g1 = ggml_add(ctx0, g1, layer.ssm_dt_b);
            g1 = ggml_softplus(ctx0, g1);
            g1 = ggml_reshape_3d(ctx0, g1, head_dim, n_head, n_tokens);

            // A_log shape is [1, n_head] or [1, n_head, 1, 1], need to broadcast to [head_dim, n_head, n_tokens]. No need to -exp(a_log) because it was done in convert_hf_to_gguf.py
            // Reshape to [1, n_head, 1] for broadcasting with g1 [head_dim, n_head, n_tokens]
            ggml_tensor * A = ggml_reshape_3d(ctx0, layer.ssm_a, 1, n_head, 1);
            g1 = ggml_mul(ctx0, g1, A);
            cb(g1, "kda_g1", il);

            g1 = ggml_reshape_4d(ctx0, g1, head_dim, n_head, n_seq_tokens, n_seqs);

            // Compute beta (mixing coefficient)
            ggml_tensor * beta = ggml_mul_mat(ctx0, layer.ssm_beta, cur);
            beta = ggml_reshape_4d(ctx0, beta, 1, n_head, n_seq_tokens, n_seqs);
            cb(beta, "kda_beta", il);

            beta = ggml_sigmoid(ctx0, beta);

            // Reshape for KDA recurrence
            // {n_embd, n_tokens} -> {n_embd, n_seq_tokens, n_seqs}
            cur = ggml_reshape_3d(ctx0, cur, cur->ne[0], n_seq_tokens, n_seqs);

            // Get the recurrent (SSM) state of the current sequences
            ggml_tensor * ssm_states_all = mctx_cur->get_s_l(il);
            ggml_tensor * state = build_rs(inp_rs, ssm_states_all, hparams.n_embd_s(), n_seqs);
            state = ggml_reshape_4d(ctx0, state, head_dim, head_dim, n_head, n_seqs);

            const float eps_norm = hparams.f_norm_rms_eps;

            Qcur = ggml_l2_norm(ctx0, Qcur, eps_norm);
            Kcur = ggml_l2_norm(ctx0, Kcur, eps_norm);

            // One token per sequence -> build_delta_net_autoregressive, otherwise build_delta_net_chunking
            std::pair<ggml_tensor *, ggml_tensor *> attn_out = n_seq_tokens == 1 ?
                build_delta_net_autoregressive(Qcur, Kcur, Vcur, g1, beta, state, il) :
                build_delta_net_chunking(Qcur, Kcur, Vcur, g1, beta, state, il);

            ggml_tensor * output = ggml_cont(ctx0, attn_out.first);
            ggml_tensor * new_state = attn_out.second;
            cb(output, "attn_output", il);
            cb(new_state, "new_state", il);

            // Update the recurrent states: write the new state back at sequence slot kv_head
            ggml_build_forward_expand(gf,
                                     ggml_cpy(ctx0, new_state,
                                              ggml_view_1d(ctx0, ssm_states_all, hparams.n_embd_s() * n_seqs,
                                                           kv_head * hparams.n_embd_s() * ggml_element_size(ssm_states_all))));

            // Output gating g2 = g_b(g_a(x))
            ggml_tensor * cur_2d = ggml_reshape_2d(ctx0, cur, cur->ne[0], n_seq_tokens * n_seqs);
            ggml_tensor * g_a = ggml_mul_mat(ctx0, layer.ssm_g_a, cur_2d);
            ggml_tensor * g2 = ggml_mul_mat(ctx0, layer.ssm_g_b, g_a);
            cb(g2, "g2 g_b(g_a(cur_2d))", il);
            g2 = ggml_reshape_3d(ctx0, g2, head_dim, n_head, n_seq_tokens * n_seqs);

            // Apply o_norm with sigmoid gating
            // Note: Kimi model uses sigmoid gating, not SiLU (despite FusedRMSNormGated default being swish)
            // Formula: output = RMSNorm(x) * sigmoid(g)
            ggml_tensor * attn_out_final = ggml_reshape_3d(ctx0, output, head_dim, n_head,  n_seq_tokens * n_seqs);
            ggml_tensor * normed = build_norm(attn_out_final, layer.ssm_o_norm, nullptr, LLM_NORM_RMS, il);
            cb(normed, "kda_normed", il);
            ggml_tensor * gate = ggml_sigmoid(ctx0, g2);
            ggml_tensor * gated = ggml_mul(ctx0, normed, gate);

            // Output projection
            gated = ggml_cont_2d(ctx0, gated, d_inner, n_tokens);
            cur = ggml_mul_mat(ctx0, layer.wo, gated);
            cb(cur, "kda_out", il);

        } else if (is_mla) {
            // === MLA Layer (Multi-head Latent Attention) ===
            // Reference: vLLM mla.py
            // Step 1: Q projection
            // vLLM Kimi: q = q_proj(hidden_states), then view as [n_tokens, n_head, qk_head_dim]
            // Per head, Qcur is laid out as [nope | rope] (see the q_nope / q_pe views below).
            // Note: Kimi MLA does NOT use RoPE (rotary_emb=None in vLLM)
            ggml_tensor * Qcur = ggml_mul_mat(ctx0, layer.wq, cur);

            // Step 2: KV compression
            // kv_cmpr_pe = kv_a_proj_with_mqa(hidden_states) -> [kv_lora_rank + qk_rope_head_dim, n_tokens]
            ggml_tensor * kv_cmpr_pe = ggml_mul_mat(ctx0, layer.wkv_a_mqa, cur);

            // Split: kv_cmpr = kv_lora[:kv_lora_rank], k_pe = kv_lora[kv_lora_rank:]
            ggml_tensor * kv_cmpr = ggml_view_2d(ctx0, kv_cmpr_pe, kv_lora_rank, n_tokens,
                ggml_row_size(kv_cmpr_pe->type, kv_lora_rank + n_embd_head_qk_rope), 0);
            ggml_tensor * k_pe = ggml_view_3d(ctx0, kv_cmpr_pe, n_embd_head_qk_rope, 1, n_tokens,
                ggml_row_size(kv_cmpr_pe->type, kv_lora_rank + n_embd_head_qk_rope),
                ggml_row_size(kv_cmpr_pe->type, kv_lora_rank + n_embd_head_qk_rope),
                ggml_row_size(kv_cmpr_pe->type, kv_lora_rank));
            // Note: Kimi MLA does NOT apply RoPE (rotary_emb=None in vLLM)
            // k_pe is used directly without RoPE
            // Normalize kv_c
            kv_cmpr = build_norm(kv_cmpr, layer.attn_kv_a_norm, nullptr, LLM_NORM_RMS, il);

            if (layer.wk_b && layer.wv_b) { // MLA KV cache enabled
                // extract q_nope: {n_embd_head_qk_nope, n_head, n_tokens}
                ggml_tensor * q_nope =
                    ggml_view_3d(ctx0, Qcur, n_embd_head_qk_nope, n_head, n_tokens, ggml_row_size(Qcur->type, n_embd_head_k_mla),
                                 ggml_row_size(Qcur->type, n_embd_head_k_mla) * n_head, 0);
                cb(q_nope, "q_nope", il);

                // and q_pe: {n_embd_head_qk_rope, n_head, n_tokens}
                ggml_tensor * q_pe = ggml_view_3d(
                    ctx0, Qcur, n_embd_head_qk_rope, n_head, n_tokens, ggml_row_size(Qcur->type, n_embd_head_k_mla),
                    ggml_row_size(Qcur->type, n_embd_head_k_mla) * n_head, ggml_row_size(Qcur->type, n_embd_head_qk_nope));
                cb(q_pe, "q_pe", il);

                // {n_embd_head_qk_nope, n_tokens, n_head}
                q_nope = ggml_permute(ctx0, q_nope, 0, 2, 1, 3);
                cb(q_nope, "q_nope_perm", il);

                // {n_embd_head_qk_nope, kv_lora_rank, n_head} x {n_embd_head_qk_nope, n_tokens, n_head}
                ggml_tensor * q_nope_absorbed = ggml_mul_mat(ctx0, layer.wk_b, q_nope);
                cb(q_nope_absorbed, "q_nope_absorbed", il);

                // {kv_lora_rank, n_head, n_tokens}
                q_nope_absorbed = ggml_permute(ctx0, q_nope_absorbed, 0, 2, 1, 3);
                cb(q_nope_absorbed, "q_nope_absorbed_perm", il);

                // {kv_lora_rank + n_embd_head_qk_rope, n_head, n_tokens}
                // absorbed nope part first, rope part last - Kcur below uses the same order
                Qcur = ggml_concat(ctx0, q_nope_absorbed, q_pe, 0);
                cb(Qcur, "Qcur", il);

                kv_cmpr = ggml_reshape_3d(ctx0, kv_cmpr, kv_lora_rank, 1, n_tokens);
                cb(kv_cmpr, "kv_cmpr_reshape", il);

                // {kv_lora_rank + n_embd_head_qk_rope, 1, n_tokens}
                ggml_tensor * Kcur = ggml_concat(ctx0, kv_cmpr, k_pe, 0);
                cb(Kcur, "Kcur", il);

                // {kv_lora_rank, 1, n_tokens}
                ggml_tensor * Vcur = kv_cmpr;
                cb(Vcur, "Vcur", il);

                cur = build_attn(inp_attn_k, layer.wo, NULL, Qcur, Kcur, Vcur, nullptr, nullptr, layer.wv_b, kq_scale_mla, il);
                cb(cur, "mla_out", il);
            } else { // MLA KV cache disabled. Fall back to MHA KV cache.
                // Q is used as projected: per head [nope | rope]
                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head_k_mla, n_head, n_tokens);
                cb(Qcur, "mla_Q", il);
                // KV decompression: kv = kv_b_proj(kv_c_normed)
                ggml_tensor * kv = ggml_mul_mat(ctx0, layer.wkv_b, kv_cmpr);
                const int64_t kv_per_head = n_embd_head_qk_nope + n_embd_head_v_mla;

                // Split kv into k_nope and v
                ggml_tensor * k_nope = ggml_view_3d(ctx0, kv, n_embd_head_qk_nope, n_head, n_tokens,
                    ggml_row_size(kv->type, kv_per_head),
                    ggml_row_size(kv->type, kv_per_head * n_head), 0);
                ggml_tensor * Vcur = ggml_view_3d(ctx0, kv, n_embd_head_v_mla, n_head, n_tokens,
                    ggml_row_size(kv->type, kv_per_head),
                    ggml_row_size(kv->type, kv_per_head * n_head),
                    ggml_row_size(kv->type, n_embd_head_qk_nope));
                Vcur = ggml_cont(ctx0, Vcur);
                cb(Vcur, "mla_V", il);

                // Concatenate k_nope + k_pe (broadcast k_pe to all heads)
                // K = [k_nope, k_pe] where k_nope is [qk_nope_head_dim, n_head, n_tokens]
                // and k_pe is [qk_rope_head_dim, 1, n_tokens] broadcast to all heads
                // Need to broadcast k_pe from [qk_rope, 1, n_tokens] to [qk_rope, n_head, n_tokens]
                ggml_tensor * k_pe_target = ggml_new_tensor_3d(ctx0, k_pe->type, n_embd_head_qk_rope, n_head, n_tokens);
                ggml_tensor * k_pe_repeated = ggml_repeat(ctx0, k_pe, k_pe_target);
                // The nope part must come first so that K lines up with the [nope | rope]
                // layout of Qcur; with the opposite order Q.K would pair mismatched components.
                ggml_tensor * Kcur = ggml_concat(ctx0, k_nope, k_pe_repeated, 0);
                cb(Kcur, "mla_K", il);

                // Direct softmax attention (with MHA KV cache)
                // Use build_attn with inp_attn for proper mask handling
                cur = build_attn(inp_attn_kv, layer.wo, NULL, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale_mla, il);
                cb(cur, "mla_out", il);
            }
        } else {
            // Unknown layer type - this should not happen
            GGML_ABORT("Kimi layer is neither KDA nor MLA - missing required tensors");
        }

        // On last layer, select only the output tokens
        if (il == n_layer - 1 && inp_out_ids) {
            cur   = ggml_get_rows(ctx0, cur,   inp_out_ids);
            inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
        }

        // Residual
        ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
        cb(ffn_inp, "ffn_inp", il);

        // FFN Norm
        cur = build_norm(ffn_inp, layer.ffn_norm, NULL, LLM_NORM_RMS, il);
        cb(cur, "ffn_norm", il);

        if ((uint32_t) il < hparams.n_layer_dense_lead) {
            // Dense FFN layer
            cur = build_ffn(cur,
                layer.ffn_up, NULL, NULL,
                layer.ffn_gate, NULL, NULL,
                layer.ffn_down, NULL, NULL,
                NULL, LLM_FFN_SILU, LLM_FFN_PAR, il);
            cb(cur, "ffn_out", il);
        } else {
            // MoE layer
            // Kimi uses moe_renormalize=True and routed_scaling_factor (stored as expert_weights_scale) = 2.446
            ggml_tensor * moe_out = build_moe_ffn(cur,
                layer.ffn_gate_inp,
                layer.ffn_up_exps,
                layer.ffn_gate_exps,
                layer.ffn_down_exps,
                layer.ffn_exp_probs_b,
                hparams.n_expert,
                hparams.n_expert_used,
                LLM_FFN_SILU, true,
                true, hparams.expert_weights_scale,
                (llama_expert_gating_func_type) hparams.expert_gating_func,
                il);
            cb(moe_out, "ffn_moe_out", il);

            // Shared expert
            {
                ggml_tensor * ffn_shexp = build_ffn(cur,
                        layer.ffn_up_shexp, NULL, NULL,
                        layer.ffn_gate_shexp, NULL, NULL,
                        layer.ffn_down_shexp, NULL, NULL,
                        NULL, LLM_FFN_SILU, LLM_FFN_PAR, il);
                cb(ffn_shexp, "ffn_shexp", il);

                cur = ggml_add(ctx0, moe_out, ffn_shexp);
                cb(cur, "ffn_out", il);
            }
        }
        // Residual
        cur = ggml_add(ctx0, cur, ffn_inp);

        cur = build_cvec(cur, il);
        cb(cur, "l_out", il);

        // Input of the next layer
        inpL = cur;
    }
    cur = inpL;

    // Final Norm
    cur = build_norm(cur, model.output_norm, NULL, LLM_NORM_RMS, -1);

    cb(cur, "result_norm", -1);
    res->t_embd = cur;

    // Output
    cur = ggml_mul_mat(ctx0, model.output, cur);
    cb(cur, "result_output", -1);
    res->t_logits = cur;

    ggml_build_forward_expand(gf, cur);
}