Coverage Report

Created: 2026-06-13 06:23

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
/src/llama.cpp/src/models/plamo2.cpp
Line
Count
Source
1
#include "models.h"
2
#include "llama-memory-recurrent.h"
3
4
0
// Reads the PLaMo-2 specific hyper-parameters from the model file.
//
// Besides the plain key/value reads this also decides, per layer, whether the
// layer is recurrent (Mamba) or attention, and derives the model size label.
void llama_model_plamo2::load_arch_hparams(llama_model_loader & ml) {
    // Epsilon used by every RMS norm in the network.
    ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);

    // Mamba (SSM) mixer hyper-parameters.
    ml.get_key(LLM_KV_SSM_CONV_KERNEL,    hparams.ssm_d_conv);
    ml.get_key(LLM_KV_SSM_INNER_SIZE,     hparams.ssm_d_inner);
    ml.get_key(LLM_KV_SSM_STATE_SIZE,     hparams.ssm_d_state);
    ml.get_key(LLM_KV_SSM_TIME_STEP_RANK, hparams.ssm_dt_rank);
    ml.get_key(LLM_KV_SSM_GROUP_COUNT,    hparams.ssm_n_group);

    // Per-head attention sizes; the trailing `false` marks these keys as not required.
    ml.get_key(LLM_KV_ATTENTION_KEY_LENGTH,   hparams.n_embd_head_k_full, false);
    ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH, hparams.n_embd_head_v_full, false);

    const auto n_layer_total = hparams.n_layer();

    // A layer that declares zero KV heads has no attention and is treated as recurrent.
    for (uint32_t il = 0; il < n_layer_total; ++il) {
        const bool has_no_kv_heads = hparams.n_head_kv(il) == 0;
        hparams.is_recr_impl[il] = has_no_kv_heads;
    }

    // Size label: 16 layers => 1B; 32 layers is split by embedding width
    // (any other width leaves `type` untouched); everything else is unknown.
    if (n_layer_total == 16) {
        type = LLM_TYPE_1B;
    } else if (n_layer_total == 32) {
        if (hparams.n_embd == 2048) {
            type = LLM_TYPE_2B;
        } else if (hparams.n_embd == 4096) {
            type = LLM_TYPE_8B;
        }
    } else {
        type = LLM_TYPE_UNKNOWN;
    }
}
34
35
0
// Declares every weight tensor of a PLaMo-2 model with its expected shape.
//
// Each layer gets either Mamba mixer weights or attention weights (chosen by
// hparams.is_recr), followed by the post-mixer norm and FFN weights shared by
// both layer kinds.
void llama_model_plamo2::load_arch_tensors(llama_model_loader &) {
    LLAMA_LOAD_LOCALS;

    // --- Mamba mixer sizes ---
    const uint32_t conv_kernel = hparams.ssm_d_conv;
    const uint32_t state_dim   = hparams.ssm_d_state;
    const uint32_t ssm_heads   = hparams.ssm_dt_rank;
    const uint32_t inner_dim   = hparams.ssm_d_inner;
    // Width of the dt input: n_embd / 16, but never below 64.
    const int64_t  dt_width    = std::max(64, int(hparams.n_embd / 16));

    // --- Attention head sizes ---
    const uint32_t head_qk = hparams.n_embd_head_k();
    const uint32_t head_v  = hparams.n_embd_head_v();

    tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);

    // Final norm and LM head. The LM head is optional; without it the
    // token-embedding tensor is requested again (TENSOR_DUPLICATED) in its place.
    output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
    output      = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED);
    if (!output) {
        output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED);
    }

    for (int il = 0; il < n_layer; ++il) {
        auto & blk = layers[il];

        blk.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", il), {n_embd}, 0);

        if (!hparams.is_recr(il)) {
            // Attention layer: fused QKV projection laid out as [Q | K | V],
            // per-head Q/K norms and the output projection.
            const int64_t n_q_heads  = hparams.n_head(il);
            const int64_t n_kv_heads = hparams.n_head_kv(il);
            const int64_t qkv_width  = n_q_heads * head_qk + n_kv_heads * head_qk + n_kv_heads * head_v;

            blk.wqkv        = create_tensor(tn(LLM_TENSOR_ATTN_QKV,    "weight", il), {n_embd, qkv_width}, 0);
            blk.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", il), {head_qk, n_q_heads}, 0);
            blk.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", il), {head_qk, n_kv_heads}, 0);
            blk.wo          = create_tensor(tn(LLM_TENSOR_ATTN_OUT,    "weight", il), {n_q_heads * head_v, n_embd}, 0);
        } else {
            // Mamba layer: input projection (z and x), depthwise conv, x -> (dt, B, C)
            // projection, dt projection + bias, A and D, output projection and the
            // RMS-norm weights applied to dt, B and C.
            blk.ssm_in     = create_tensor(tn(LLM_TENSOR_SSM_IN,     "weight", il), {n_embd, 2 * inner_dim}, 0);
            blk.ssm_conv1d = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "weight", il), {conv_kernel, inner_dim}, 0);

            blk.ssm_x    = create_tensor(tn(LLM_TENSOR_SSM_X,  "weight", il), {inner_dim, dt_width + 2*state_dim}, 0);
            blk.ssm_dt   = create_tensor(tn(LLM_TENSOR_SSM_DT, "weight", il), {dt_width, ssm_heads}, 0);
            blk.ssm_dt_b = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias",   il), {ssm_heads}, 0);

            blk.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A, il), {ssm_heads}, 0);
            blk.ssm_d = create_tensor(tn(LLM_TENSOR_SSM_D, il), {ssm_heads}, 0);

            blk.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", il), {inner_dim, n_embd}, 0);

            blk.ssm_dt_norm = create_tensor(tn(LLM_TENSOR_SSM_DT_NORM, il), {dt_width}, 0);
            blk.ssm_b_norm  = create_tensor(tn(LLM_TENSOR_SSM_B_NORM,  il), {state_dim}, 0);
            blk.ssm_c_norm  = create_tensor(tn(LLM_TENSOR_SSM_C_NORM,  il), {state_dim}, 0);
        }

        // Shared by both layer kinds: post-mixer norm, FFN norms and the FFN itself
        // (ffn_up holds the gate and up halves, hence 2 * n_ff).
        blk.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, il), {n_embd}, 0);
        blk.ffn_norm       = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", il), {n_embd}, 0);
        blk.ffn_down       = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", il), {n_ff, n_embd}, 0);
        blk.ffn_up         = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", il), {n_embd, n_ff * 2}, 0);
        blk.ffn_post_norm  = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, il), {n_embd}, 0);
    }
}
105
106
0
std::unique_ptr<llm_graph_context> llama_model_plamo2::build_arch_graph(const llm_graph_params & params) const {
107
0
    return std::make_unique<graph>(*this, params);
108
0
}
109
110
// Builds the full PLaMo-2 forward graph.
//
// Every layer is: RMS pre-norm -> (Mamba | attention) mixer -> RMS post-norm ->
// residual add -> RMS pre-norm -> SwiGLU FFN -> RMS post-norm -> residual add.
// The stack is followed by a final RMS norm and the LM head; the normed hidden
// state is exposed as res->t_embd and the logits as res->t_logits.
llama_model_plamo2::graph::graph(const llama_model & model, const llm_graph_params & params) :
    llm_build_mamba_base(params) {
    // Running hidden state of the stack, starting from the token embeddings, {n_embd, n_tokens}.
    ggml_tensor * hidden = build_inp_embd(model.tok_embd);
    cb(hidden, "embedding_output", -1);

    // Remaining graph inputs: RoPE positions, the hybrid (attention KV + recurrent
    // state) memory input, and the indices of the rows whose outputs are requested.
    ggml_tensor * inp_pos     = build_inp_pos();
    auto *        inp_hybrid  = build_inp_mem_hybrid();
    ggml_tensor * inp_out_ids = build_inp_out_ids();

    for (int il = 0; il < n_layer; ++il) {
        const auto & blk        = model.layers[il];
        const bool   last_layer = il == n_layer - 1;

        // --- token mixer ---
        ggml_tensor * mix = build_norm(hidden, blk.attn_norm, NULL, LLM_NORM_RMS, il);
        mix = hparams.is_recr(il)
            ? build_plamo2_mamba_layer(inp_hybrid->get_recr(), mix, model, ubatch, il)
            : build_plamo2_attn_layer(inp_hybrid->get_attn(), inp_pos, mix, model, il);

        mix = build_norm(mix, blk.attn_post_norm, NULL, LLM_NORM_RMS, il);
        cb(mix, "attn_post_norm", il);

        ggml_tensor * after_mix = ggml_add(ctx0, mix, hidden);
        cb(after_mix, "attn_residual", il);

        // --- feed-forward ---
        ggml_tensor * ffn = build_norm(after_mix, blk.ffn_norm, NULL, LLM_NORM_RMS, il);
        cb(ffn, "ffn_pre_norm", il);

        ffn = build_ffn(ffn,
                blk.ffn_up,   NULL, NULL,
                NULL,         NULL, NULL,
                blk.ffn_down, NULL, NULL,
                NULL, LLM_FFN_SWIGLU, LLM_FFN_SEQ, il);
        cb(ffn, "ffn_out", il);

        ffn = build_norm(ffn, blk.ffn_post_norm, NULL, LLM_NORM_RMS, il);
        cb(ffn, "ffn_post_norm", il);

        // On the last layer, keep only the requested output rows in both operands
        // of the residual add.
        ggml_tensor * skip = after_mix;
        if (last_layer && inp_out_ids) {
            ffn  = ggml_get_rows(ctx0, ffn,  inp_out_ids);
            skip = ggml_get_rows(ctx0, skip, inp_out_ids);
        }

        hidden = ggml_add(ctx0, ffn, skip);
        cb(hidden, "ffn_residual", il);
    }

    // Final norm; its output is the embedding result.
    ggml_tensor * normed = build_norm(hidden, model.output_norm, NULL, LLM_NORM_RMS, -1);
    cb(normed, "result_norm", -1);
    res->t_embd = normed;

    // LM head.
    ggml_tensor * logits = build_lora_mm(model.output, normed, model.output_s);
    cb(logits, "result_output", -1);

    // Explicitly mark as output tensor to ensure proper backend assignment
    ggml_set_output(logits);
    res->t_logits = logits;

    ggml_build_forward_expand(gf, logits);
}
202
203
// Builds the attention mixer of PLaMo-2 layer `il`.
//
// `cur` is projected by the fused QKV weight. The result is split into Q, K and V
// views, Q and K are RMS-normalised per head and rotated with RoPE (positions from
// `inp_pos`), and build_attn applies attention over the KV cache input `inp`
// followed by the `wo` output projection.
//
// Returns the attention output as produced by build_attn.
ggml_tensor * llama_model_plamo2::graph::build_plamo2_attn_layer(llm_graph_input_attn_kv * inp,
                                                        ggml_tensor *             inp_pos,
                                                        ggml_tensor *             cur,
                                                        const llama_model &       model,
                                                        int                       il) {
    // self-attention
    {
        // PLaMo-2 uses combined QKV tensor
        ggml_tensor * qkv = build_lora_mm(model.layers[il].wqkv, cur);
        cb(qkv, "wqkv", il);

        // Q and K share the per-head size n_embd_head_k; V uses n_embd_head_v.
        const int64_t n_embd_head_q = hparams.n_embd_head_k();
        const int64_t n_embd_head_k = hparams.n_embd_head_k();
        const int64_t n_embd_head_v = hparams.n_embd_head_v();
        const int32_t n_head        = hparams.n_head(il);
        const int32_t n_head_kv     = hparams.n_head_kv(il);

        // Each qkv row is laid out as [ Q (n_head heads) | K (n_head_kv heads) | V (n_head_kv heads) ].
        const int64_t q_offset = 0;
        const int64_t k_offset = n_embd_head_q * n_head;
        const int64_t v_offset = k_offset + n_embd_head_k * n_head_kv;

        // Derive strides from the tensor's actual element size (the offsets already
        // did) instead of assuming float.
        const size_t esz = ggml_element_size(qkv);

        ggml_tensor * Qcur = ggml_view_3d(ctx0, qkv, n_embd_head_q, n_head, n_tokens, n_embd_head_q * esz,
                                          qkv->nb[1], q_offset * esz);
        ggml_tensor * Kcur = ggml_view_3d(ctx0, qkv, n_embd_head_k, n_head_kv, n_tokens, n_embd_head_k * esz,
                                          qkv->nb[1], k_offset * esz);
        ggml_tensor * Vcur = ggml_view_3d(ctx0, qkv, n_embd_head_v, n_head_kv, n_tokens, n_embd_head_v * esz,
                                          qkv->nb[1], v_offset * esz);

        cb(Qcur, "Qcur", il);
        cb(Kcur, "Kcur", il);
        cb(Vcur, "Vcur", il);

        Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il);
        cb(Qcur, "Qcur_normed", il);

        Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                             ext_factor, attn_factor, beta_fast, beta_slow);

        Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il);
        cb(Kcur, "Kcur_normed", il);

        Kcur = ggml_rope_ext(ctx0, Kcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                             ext_factor, attn_factor, beta_fast, beta_slow);

        // Scaled dot-product attention divides Q.K^T by sqrt(d_k), where d_k is the
        // query/key head size. Using the value head size was only correct when
        // n_embd_head_k == n_embd_head_v.
        const float kq_scale = 1.0f / sqrtf(float(n_embd_head_k));

        cur = build_attn(inp,
            model.layers[il].wo, NULL, model.layers[il].wo_s,
            Qcur, Kcur, Vcur, NULL, NULL, NULL, kq_scale, il);
    }

    cb(cur, "attn_out", il);

    return cur;
}
257
258
// Builds the Mamba (selective state-space) mixer of PLaMo-2 layer `il`.
//
// Pipeline:
//   in_proj -> split per head into gate z and input x
//   -> causal conv1d (prefixed with the cached conv state) + SiLU
//   -> x_proj into B, C, dt, each RMS-normalised -> dt_proj + bias
//   -> ggml_ssm_scan (seeded with the cached SSM state) -> + D*x
//   -> ggml_swiglu_split gate with z -> out_proj.
// The last conv columns and the final SSM state of every sequence in the ubatch
// are copied back into this layer's recurrent-state buffers.
//
// inp    - recurrent-state input of the hybrid memory.
// cur    - normalised layer input, {n_embd, n_tokens}.
// ubatch - must be an equal-length split: n_tokens == n_seq_tokens * n_seqs.
// Returns {n_embd, n_tokens}.
ggml_tensor * llama_model_plamo2::graph::build_plamo2_mamba_layer(llm_graph_input_rs * inp,
                                                         ggml_tensor *        cur,
                                                         const llama_model &  model,
                                                         const llama_ubatch & ubatch,
                                                         int                  il) {
    const auto * mctx_cur = inp->mctx;

    // First recurrent-state slot written by this ubatch; the stores below write
    // the n_seqs per-sequence states into consecutive slots starting here.
    const auto kv_head = mctx_cur->get_head();

    const int64_t d_conv       = hparams.ssm_d_conv;
    const int64_t d_inner      = hparams.ssm_d_inner;
    const int64_t d_state      = hparams.ssm_d_state;
    const int64_t n_heads      = hparams.ssm_dt_rank;
    const int64_t n_group      = hparams.ssm_n_group;
    const int64_t n_seqs       = ubatch.n_seqs;
    const int64_t n_seq_tokens = ubatch.n_seq_tokens;

    GGML_ASSERT(n_seqs != 0);
    GGML_ASSERT(ubatch.equal_seqs());
    GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs);
    GGML_ASSERT(n_heads > 0); // guard the divisions below
    GGML_ASSERT(d_inner % n_heads == 0);
    GGML_ASSERT(n_group == 0);

    const int64_t head_dim = d_inner / n_heads;

    // Number of elements in ONE slot (one sequence) of each per-layer state buffer.
    const int64_t conv_state_size = (d_conv - 1) * (d_inner + 2 * n_group * d_state);
    const int64_t ssm_state_size  = n_heads * head_dim * d_state;

    ggml_tensor * conv_states_all = mctx_cur->get_r_l(il);
    ggml_tensor * ssm_states_all  = mctx_cur->get_s_l(il);

    ggml_tensor * conv = build_rs(inp, conv_states_all, hparams.n_embd_r(), n_seqs);
    conv               = ggml_reshape_3d(ctx0, conv, d_conv - 1, d_inner + 2 * n_group * d_state, n_seqs);

    // {n_embd, n_tokens} => {n_embd, n_seq_tokens, n_seqs}
    cur = ggml_reshape_3d(ctx0, cur, cur->ne[0], n_seq_tokens, n_seqs);

    // in_proj: {n_embd, 2*d_inner} @ {n_embd, n_seq_tokens, n_seqs} => {2*d_inner, n_seq_tokens, n_seqs}
    ggml_tensor * zx = build_lora_mm(model.layers[il].ssm_in, cur);
    cb(zx, "mamba_in_proj", il);
    // {2*d_inner, n_seq_tokens, n_seqs} => {2*head_dim, n_heads, n_seq_tokens, n_seqs}
    // The token and sequence axes must keep their order. Swapping them before this
    // reshape would mix tokens of different sequences whenever n_seqs > 1 and
    // n_seq_tokens > 1.
    zx = ggml_cont_4d(ctx0, zx, head_dim * 2, n_heads, n_seq_tokens, n_seqs);
    cb(zx, "mamba_in_proj_out", il);

    // split into z and x: each head's 2*head_dim slice is [z | x]
    // => {head_dim * n_heads, n_seq_tokens, n_seqs}
    ggml_tensor * x = ggml_view_4d(ctx0, zx, head_dim, n_heads, n_seq_tokens, n_seqs, zx->nb[1], zx->nb[2], zx->nb[3],
                                   head_dim * ggml_element_size(zx));
    x               = ggml_cont_3d(ctx0, x, head_dim * n_heads, n_seq_tokens, n_seqs);
    cb(x, "mamba_x_split", il);

    ggml_tensor * z =
        ggml_view_4d(ctx0, zx, head_dim, n_heads, n_seq_tokens, n_seqs, zx->nb[1], zx->nb[2], zx->nb[3], 0);
    cb(z, "mamba_z_split", il);

    // conv1d
    {
        // => {d_conv - 1 + n_seq_tokens, d_inner, n_seqs}
        ggml_tensor * conv_x = ggml_concat(ctx0, conv, ggml_transpose(ctx0, x), 0);
        cb(conv_x, "mamba_conv1d_input", il);

        // copy last (d_conv - 1) columns back into the state cache
        ggml_tensor * last_conv = ggml_view_3d(ctx0, conv_x, d_conv - 1, d_inner, n_seqs, conv_x->nb[1], conv_x->nb[2],
                                               n_seq_tokens * (conv_x->nb[0]));

        ggml_build_forward_expand(gf, ggml_cpy(ctx0, last_conv,
                                               ggml_view_1d(ctx0, conv_states_all,
                                                            conv_state_size * n_seqs,
                                                            kv_head * conv_state_size *
                                                                ggml_element_size(conv_states_all))));
        cb(conv_states_all, "mamba_conv1d_state", il);

        // 1D convolution
        x = ggml_ssm_conv(ctx0, conv_x, model.layers[il].ssm_conv1d);
        cb(x, "mamba_conv1d", il);

        x = ggml_silu(ctx0, x);
        cb(x, "mamba_conv1d_silu", il);
    }

    // SSM
    {
        // bcdt_proj: {d_inner, dt_dim + 2*d_state} @ {d_inner, n_seq_tokens, n_seqs} => {dt_dim + 2*d_state, n_seq_tokens, n_seqs}
        ggml_tensor * x_bcdt = build_lora_mm(model.layers[il].ssm_x, x);
        cb(x_bcdt, "mamba_bcdt_proj", il);

        // split into B, C, dt (stored in that order along dim 0)
        const int64_t dt_dim = std::max(64, int(hparams.n_embd / 16));
        ggml_tensor * B  = ggml_view_3d(ctx0, x_bcdt, d_state, n_seq_tokens, n_seqs, x_bcdt->nb[1], x_bcdt->nb[2], 0);
        ggml_tensor * C  = ggml_view_3d(ctx0, x_bcdt, d_state, n_seq_tokens, n_seqs, x_bcdt->nb[1], x_bcdt->nb[2],
                                        ggml_element_size(x_bcdt) * d_state);
        ggml_tensor * dt = ggml_view_3d(ctx0, x_bcdt, dt_dim, n_seq_tokens, n_seqs, x_bcdt->nb[1], x_bcdt->nb[2],
                                        ggml_element_size(x_bcdt) * (2 * d_state));
        cb(B, "mamba_B_raw", il);
        cb(C, "mamba_C_raw", il);
        cb(dt, "mamba_dt_raw", il);

        // Apply RMS norm to dt, B, C (PLaMo-2 specific)
        B  = build_norm(B, model.layers[il].ssm_b_norm, NULL, LLM_NORM_RMS, il);
        C  = build_norm(C, model.layers[il].ssm_c_norm, NULL, LLM_NORM_RMS, il);
        dt = build_norm(dt, model.layers[il].ssm_dt_norm, NULL, LLM_NORM_RMS, il);
        cb(B, "mamba_B_normed", il);
        cb(C, "mamba_C_normed", il);
        cb(dt, "mamba_dt_normed", il);

        // dt_proj: {dt_dim, n_heads} @ {dt_dim, n_seq_tokens, n_seqs} => {n_heads, n_seq_tokens, n_seqs}
        dt = build_lora_mm(model.layers[il].ssm_dt, dt);
        dt = ggml_add(ctx0, dt, model.layers[il].ssm_dt_b);
        cb(dt, "mamba_dt_proj", il);

        ggml_tensor * A = ggml_reshape_2d(ctx0, model.layers[il].ssm_a, 1, n_heads);
        cb(A, "mamba_A", il);

        x = ggml_view_4d(ctx0, x, head_dim, n_heads, n_seq_tokens, n_seqs, head_dim * ggml_element_size(x),
                         head_dim * n_heads * ggml_element_size(x),
                         head_dim * n_heads * n_seq_tokens * ggml_element_size(x), 0);
        B = ggml_view_4d(ctx0, B, d_state, 1, n_seq_tokens, n_seqs, d_state * B->nb[0], B->nb[1], B->nb[2], 0);
        C = ggml_view_4d(ctx0, C, d_state, 1, n_seq_tokens, n_seqs, d_state * C->nb[0], C->nb[1], C->nb[2], 0);

        // use the states and the indices provided by build_recurrent_state
        // (this is necessary in order to properly use the states before they are overwritten,
        //  while avoiding to make unnecessary copies of the states)
        auto get_ssm_rows = [&](ggml_context * ctx, ggml_tensor * states, ggml_tensor * ids) {
            ggml_tensor * ssm = ggml_reshape_4d(ctx, states, d_state, head_dim, n_heads, mctx_cur->get_size());

            // Custom operator to optimize the parallel associative scan
            // as described in the Annex D of the Mamba paper.
            // => {d_inner, n_seq_tokens, n_seqs} and {d_state, d_inner, n_seqs}
            return ggml_ssm_scan(ctx, ssm, x, dt, A, B, C, ids);
        };

        ggml_tensor * y_ssm = build_rs(inp, ssm_states_all, hparams.n_embd_s(), ubatch.n_seqs, get_ssm_rows);
        cb(y_ssm, "mamba_ssm_scan", il);

        // Store the final states. y_ssm holds the outputs (d_inner * n_seq_tokens * n_seqs
        // elements) followed by one ssm_state_size-element state per sequence. Each
        // sequence owns one slot, so the destination starts at slot kv_head, exactly
        // like the conv-state store above. The previous offset, kv_head * n_seqs slots,
        // wrote to the wrong slots (and past the buffer) when kv_head > 0 and n_seqs > 1.
        ggml_build_forward_expand(
            gf, ggml_cpy(
                    ctx0,
                    ggml_view_1d(ctx0, y_ssm, ssm_state_size * n_seqs,
                                 n_heads * head_dim * n_seq_tokens * n_seqs * ggml_element_size(y_ssm)),
                    ggml_view_1d(ctx0, ssm_states_all, ssm_state_size * n_seqs,
                                 kv_head * ssm_state_size * ggml_element_size(ssm_states_all))));
        cb(ssm_states_all, "mamba_ssm_states", il);

        ggml_tensor * y = ggml_view_4d(ctx0, y_ssm, head_dim, n_heads, n_seq_tokens, n_seqs,
                                       head_dim * ggml_element_size(y_ssm), head_dim * n_heads * ggml_element_size(y_ssm),
                                       head_dim * n_heads * n_seq_tokens * ggml_element_size(y_ssm), 0);
        cb(y, "mamba_y_view", il);

        // Add D parameter and apply gating with z
        // {d_inner, n_seq_tokens, n_seqs} * {d_inner} => {d_inner, n_seq_tokens, n_seqs}
        ggml_tensor * D = ggml_reshape_2d(ctx0, model.layers[il].ssm_d, 1, n_heads);
        y               = ggml_add(ctx0, y, ggml_mul(ctx0, x, D));
        cb(y, "mamba_y_add_d", il);

        y = ggml_swiglu_split(ctx0, ggml_cont(ctx0, z), y);
        cb(y, "mamba_y_swiglu_z", il);

        // out_proj: {d_inner, n_embd} @ {d_inner, n_seq_tokens, n_seqs} => {n_embd, n_seq_tokens, n_seqs}
        y   = ggml_view_3d(ctx0, y, head_dim * n_heads, n_seq_tokens, n_seqs, y->nb[2], y->nb[3], 0);
        cur = build_lora_mm(model.layers[il].ssm_out, y);
        cb(cur, "mamba_out_proj", il);
    }

    // {n_embd, n_seq_tokens, n_seqs} => {n_embd, n_tokens}
    cur = ggml_reshape_2d(ctx0, cur, cur->ne[0], n_seq_tokens * n_seqs);
    cb(cur, "mamba_out", il);

    return cur;
}