Coverage Report

Created: 2026-06-13 06:24

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
/src/llama.cpp/src/models/delta-net-base.cpp
Line
Count
Source
1
#include "models.h"
2
3
#include "llama-impl.h"
4
#include "llama-memory-recurrent.h"
5
6
// utility to get one slice from the third dimension
7
// input dim:  [x, y, c, b]
8
// output dim: [x, y, 1, b]
9
0
static ggml_tensor * get_slice_2d(ggml_context * ctx0, ggml_tensor * t, int64_t c) {
10
0
    return ggml_view_4d(ctx0, t, t->ne[0], t->ne[1], 1, t->ne[3],
11
0
        t->nb[1], t->nb[2], t->nb[3], t->nb[2] * c);
12
0
}
13
14
0
// Forwards the graph parameters to the llm_graph_context base; nothing else is set up here.
llm_build_delta_net_base::llm_build_delta_net_base(const llm_graph_params & params) :
    llm_graph_context(params) {
}
15
16
std::pair<ggml_tensor *, ggml_tensor *> llm_build_delta_net_base::build_delta_net_chunking(
17
        ggml_tensor * q,
18
        ggml_tensor * k,
19
        ggml_tensor * v,
20
        ggml_tensor * g,
21
        ggml_tensor * b,
22
        ggml_tensor * s,
23
0
        int           il) {
24
0
    const int64_t S_k      = q->ne[0];
25
0
    const int64_t H_k      = q->ne[1];
26
0
    const int64_t n_tokens = q->ne[2];
27
0
    const int64_t n_seqs   = q->ne[3];
28
29
0
    const int64_t S_v = v->ne[0];
30
0
    const int64_t H_v = v->ne[1];
31
0
    const bool kda = (g->ne[0] == S_k && g->ne[1] == H_k);
32
33
0
    GGML_ASSERT(S_k == S_v);
34
0
    GGML_ASSERT(H_v % H_k == 0);
35
36
0
    GGML_ASSERT(q->ne[0] == S_k && q->ne[1] == H_k && q->ne[2] == n_tokens && q->ne[3] == n_seqs);
37
0
    GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == H_k && k->ne[2] == n_tokens && k->ne[3] == n_seqs);
38
0
    GGML_ASSERT(v->ne[0] == S_v && v->ne[1] == H_v && v->ne[2] == n_tokens && v->ne[3] == n_seqs);
39
40
0
    GGML_ASSERT(g->ne[0] == 1   || g->ne[0] == S_v);
41
0
    GGML_ASSERT(                   g->ne[1] == H_v && g->ne[2] == n_tokens && g->ne[3] == n_seqs);
42
0
    GGML_ASSERT(b->ne[0] == 1   && b->ne[1] == H_v && b->ne[2] == n_tokens && b->ne[3] == n_seqs);
43
0
    GGML_ASSERT(s->ne[0] == S_v && s->ne[1] == S_v && s->ne[2] == H_v      && s->ne[3] == n_seqs);
44
45
0
    const float scale = 1.0f / sqrtf(S_k);
46
47
0
    q = ggml_scale(ctx0, q, scale);
48
49
0
    cb(q, "q_in", il);
50
0
    cb(k, "k_in", il);
51
0
    cb(v, "v_in", il);
52
0
    cb(b, "b_in", il);
53
0
    cb(g, "g_in", il);
54
55
0
    q = ggml_permute(ctx0, q, 0, 2, 1, 3); // [S_k, n_tokens, H_k, n_seqs]
56
0
    k = ggml_permute(ctx0, k, 0, 2, 1, 3); // [S_k, n_tokens, H_k, n_seqs]
57
0
    v = ggml_permute(ctx0, v, 0, 2, 1, 3); // [S_v, n_tokens, H_v, n_seqs]
58
0
    g = ggml_permute(ctx0, g, 0, 2, 1, 3); // [g_0, n_tokens, H_v, n_seqs]
59
0
    b = ggml_permute(ctx0, b, 0, 2, 1, 3); // [  1, n_tokens, H_v, n_seqs]
60
61
0
    const int CS = kda ? 16 : 64; // chunk size
62
63
0
    const int pad = (CS - n_tokens % CS) % CS;
64
0
    const int n_chunks = (n_tokens + pad) / CS;
65
66
0
    q = ggml_pad(ctx0, q, 0, pad, 0, 0);
67
0
    k = ggml_pad(ctx0, k, 0, pad, 0, 0);
68
0
    v = ggml_pad(ctx0, v, 0, pad, 0, 0);
69
0
    g = ggml_pad(ctx0, g, 0, pad, 0, 0);
70
0
    b = ggml_pad(ctx0, b, 0, pad, 0, 0);
71
72
0
    ggml_tensor * v_b = ggml_mul(ctx0, v, b);
73
0
    ggml_tensor * k_b = ggml_mul(ctx0, k, b);
74
75
0
    cb(v_b, "v_b", il);
76
0
    cb(k_b, "k_b", il);
77
78
0
    q   = ggml_reshape_4d(ctx0, q,   S_k, CS, n_chunks, H_k * n_seqs);
79
0
    k   = ggml_reshape_4d(ctx0, k,   S_k, CS, n_chunks, H_k * n_seqs);
80
0
    k_b = ggml_reshape_4d(ctx0, k_b, S_k, CS, n_chunks, H_v * n_seqs);
81
0
    v   = ggml_reshape_4d(ctx0, v,   S_v, CS, n_chunks, H_v * n_seqs);
82
0
    v_b = ggml_reshape_4d(ctx0, v_b, S_v, CS, n_chunks, H_v * n_seqs);
83
84
0
    g = ggml_reshape_4d(ctx0, g, g->ne[0], CS, n_chunks, H_v * n_seqs);
85
0
    b = ggml_reshape_4d(ctx0, b, 1,        CS, n_chunks, H_v * n_seqs);
86
87
    // [CS, g_0, n_chunks, H_v * n_seqs]
88
    // TODO: extend ggml_cumsum with axis parameter to avoid transpose
89
0
    ggml_tensor * g_cs = ggml_cumsum(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, g)));
90
0
    cb(g_cs, "g_cs", il);
91
92
0
    ggml_tensor * kb = nullptr;
93
0
    ggml_tensor * kq = nullptr;
94
0
    if (kda) {
95
0
        const int64_t CHB = n_chunks * H_k * n_seqs;
96
97
0
        ggml_tensor * g_cs_i = ggml_reshape_4d(ctx0, g_cs, CS, 1, S_k, CHB);  // [chunk_size, 1, S_k, CHB]
98
0
        ggml_tensor * g_cs_j = ggml_reshape_4d(ctx0, g_cs, 1, CS, S_k, CHB);  // [1, chunk_size, S_k, CHB]
99
100
0
        g_cs_j = ggml_repeat_4d(ctx0, g_cs_j, CS, CS, S_k, CHB);  // [1, chunk_size, S_k, CHB] -> [chunk_size, chunk_size, S_k, CHB]
101
102
        // decay_mask [chunk_size,chunk_size,S_k,CHB]
103
0
        ggml_tensor * decay_mask;
104
0
        decay_mask = ggml_sub(ctx0, g_cs_j, g_cs_i);
105
0
        decay_mask = ggml_tri(ctx0, decay_mask, GGML_TRI_TYPE_LOWER_DIAG);
106
0
        decay_mask = ggml_exp(ctx0, decay_mask);
107
0
        cb(decay_mask, "decay_mask", il);
108
109
        // decay_mask [S_k,BT_j,BT_i,CHB] *Note* second and third chunk_sizes are switched
110
0
        decay_mask = ggml_cont_4d(ctx0, ggml_permute(ctx0, decay_mask, 2, 1, 0, 3), S_k, CS, CS, CHB);
111
112
0
        ggml_tensor * k_b_i = ggml_reshape_4d(ctx0, k_b, S_k, CS,  1, CHB);
113
0
        ggml_tensor * k_j   = ggml_reshape_4d(ctx0, k,   S_k,  1, CS, CHB);
114
0
        ggml_tensor * q_i   = ggml_reshape_4d(ctx0, q,   S_k, CS,  1, CHB);
115
116
0
        ggml_tensor * decay_k_b_i = ggml_mul(ctx0, decay_mask, k_b_i);
117
0
        ggml_tensor * decay_q_i   = ggml_mul(ctx0, decay_mask, q_i);
118
119
        // decay_k_b_i [S,BT,BT,CHB] @ k_j [S,1,BT,CHB] = Akk [BT,1,BT,CHB]
120
0
        kb = ggml_mul_mat(ctx0, decay_k_b_i, k_j);
121
0
        kq = ggml_mul_mat(ctx0, decay_q_i,   k_j);
122
123
0
        kb = ggml_cont(ctx0, ggml_transpose(ctx0, ggml_reshape_4d(ctx0, kb, CS, CS, n_chunks, H_v * n_seqs)));
124
0
        kq = ggml_cont(ctx0, ggml_transpose(ctx0, ggml_reshape_4d(ctx0, kq, CS, CS, n_chunks, H_v * n_seqs)));
125
0
    } else {
126
0
        ggml_tensor * g_cs_i = g_cs;
127
0
        ggml_tensor * g_cs_j = ggml_reshape_4d(ctx0, g_cs, 1, CS, n_chunks, H_v * n_seqs);
128
129
0
        g_cs_j = ggml_repeat_4d(ctx0, g_cs_j, CS, CS, n_chunks, H_v * n_seqs);
130
131
        // [CS, CS, n_chunks, H_v * n_seqs]
132
0
        ggml_tensor * decay_mask;
133
0
        decay_mask = ggml_sub(ctx0, g_cs_j, g_cs_i);
134
0
        decay_mask = ggml_tri(ctx0, decay_mask, GGML_TRI_TYPE_LOWER_DIAG);
135
0
        decay_mask = ggml_exp(ctx0, decay_mask);
136
0
        cb(decay_mask, "decay_mask", il);
137
138
        // [CS, CS, n_chunks, H_k * n_seqs]
139
0
        kb = ggml_mul_mat(ctx0, k,  k_b);
140
0
        kb = ggml_mul    (ctx0, kb, decay_mask);
141
142
        // [CS, CS, n_chunks, H_k * n_seqs]
143
0
        kq = ggml_mul_mat(ctx0, k, q);
144
0
        kq = ggml_mul(ctx0, kq, decay_mask);
145
0
    }
146
147
0
    kq = ggml_tri(ctx0, kq, GGML_TRI_TYPE_LOWER_DIAG);
148
0
    cb(kq, "kq", il);
149
150
    // [CS, CS, n_chunks, H_k * n_seqs]
151
0
    ggml_tensor * attn;
152
0
    attn = ggml_tri(ctx0, kb, GGML_TRI_TYPE_LOWER);
153
0
    cb(attn, "attn", il);
154
155
0
    ggml_tensor * identity;
156
0
    identity = ggml_view_1d(ctx0, attn, CS, 0);
157
0
    identity = ggml_fill   (ctx0, identity, 1.0f);
158
0
    identity = ggml_diag   (ctx0, identity);
159
160
0
    ggml_tensor * lhs = ggml_add(ctx0, attn, identity);
161
0
    cb(lhs, "dnet_add_ch_lhs", il);
162
163
0
    attn = ggml_neg(ctx0, attn);
164
0
    cb(attn, "attn_pre_solve", il);
165
166
0
    ggml_tensor * lin_solve = ggml_solve_tri(ctx0, lhs, attn, true, true, false);
167
0
    attn = ggml_add(ctx0, lin_solve, identity);
168
0
    cb(attn, "dnet_add_ch_attn_solved", il); // [CS, CS, n_chunks, H_k * n_seqs]
169
170
    // [S_v, CS, n_chunks, H_v * n_seqs]
171
0
    v = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, v_b)), attn);
172
173
    // [CS, 1, n_chunks, H_v * n_seqs] KDA: [CS, S_k, n_chunks, H_v * n_seqs]
174
0
    ggml_tensor * g_exp = ggml_exp(ctx0, g_cs);
175
176
0
    k_b = ggml_cont(ctx0, ggml_transpose(ctx0, k_b));
177
178
    // [CS, S_k, n_chunks, H_k * n_seqs]
179
0
    ggml_tensor * kbg = ggml_mul(ctx0, k_b, g_exp);
180
0
    cb(kbg, "k_beta_g_exp", il);
181
182
    // [S_k, CS, n_chunks, H_k * n_seqs]
183
0
    ggml_tensor * k_cd = ggml_mul_mat(ctx0, kbg, attn);
184
0
    cb(k_cd, "k_cumdecay", il);
185
186
    // [1, CS, n_chunks, H_k * n_seqs] KDA: [S_k, CS, n_chunks, H_k * n_seqs]
187
0
    ggml_tensor * g_exp_t = ggml_cont(ctx0, ggml_transpose(ctx0, g_exp));
188
0
    ggml_tensor * q_g_exp = ggml_mul(ctx0, q, g_exp_t);
189
190
    // vectorized calculation of key_gdiff
191
    // improved from the chunked version:
192
    //   g_last = torch.clamp(g_cum[:, :, -1], max=50.0).exp().unsqueeze(-1).unsqueeze(-1)
193
    //   g_diff = torch.clamp(g_cum[:, :, -1:] - g_cum, max=50.0).exp()
194
    //   key_gdiff = key * g_diff.unsqueeze(-1)
195
    //   kgdmulvnew = (key_gdiff).transpose(-1, -2) @ v_new
196
    //   last_recurrent_state = last_recurrent_state * g_last + kgdmulvnew
197
198
    // get last element in g_cumsum along CS dimension (ne0)
199
    // example: [[x, y, z, ..., last], ...] -> [[last], ...]
200
    // [1, 1, n_chunks, H_v * n_seqs] KDA: [1, S_k, n_chunks, H_v * n_seqs]
201
0
    ggml_tensor * g_last = ggml_view_4d(ctx0, g_cs, 1, g_cs->ne[1], g_cs->ne[2], g_cs->ne[3],
202
0
            g_cs->nb[1],
203
0
            g_cs->nb[2],
204
0
            g_cs->nb[3],
205
0
            ggml_row_size(g_cs->type, g_cs->ne[0] - 1));
206
0
    cb(g_last, "g_last", il);
207
208
    // TODO: remove this cont when CUDA supports non-cont unary ops
209
0
    g_last = ggml_cont(ctx0, g_last);
210
211
    // [1, 1, n_chunks, H_v * n_seqs] KDA: [S_k, 1, n_chunks, H_v * n_seqs]
212
0
    ggml_tensor * g_last_exp_t = ggml_transpose(ctx0, ggml_exp(ctx0, g_last));
213
0
    cb(g_last_exp_t, "g_last_exp_t", il);
214
215
    // [CS, 1, n_chunks, H_v * n_seqs] KDA: [CS, S_k, n_chunks, H_v * n_seqs]
216
0
    ggml_tensor * g_diff = ggml_neg(ctx0, ggml_sub(ctx0, g_cs, g_last));
217
0
    cb(g_diff, "g_diff", il);
218
219
0
    ggml_tensor * g_diff_exp_t = ggml_cont(ctx0, ggml_transpose(ctx0, ggml_exp(ctx0, g_diff)));
220
221
    // [S_k, CS, n_chunks, H_v * n_seqs]
222
0
    ggml_tensor * kg = ggml_mul(ctx0, k, g_diff_exp_t);
223
0
    cb(kg, "key_gdiff", il);
224
225
    // [CS, S_k, n_chunks, H_v * n_seqs]
226
0
    ggml_tensor * kg_t = ggml_cont(ctx0, ggml_transpose(ctx0, kg));
227
0
    cb(kg_t, "key_gdiff_t", il);
228
229
0
    s = ggml_reshape_4d(ctx0, s, S_v, S_v, 1, H_v * n_seqs);
230
0
    cb(s, "dnet_add_ch_state", il);
231
232
    // [CS, S_v, n_chunks, H_v * n_seqs]
233
0
    ggml_tensor * v_t = ggml_cont(ctx0, ggml_transpose(ctx0, v));
234
235
0
    for (int64_t chunk = 0; chunk < n_chunks; chunk++) {
236
0
        ggml_tensor * ch_k_cd    = get_slice_2d(ctx0, k_cd,    chunk); // [S_k,  CS, 1, H_k * n_seqs]
237
0
        ggml_tensor * ch_v_t     = get_slice_2d(ctx0, v_t,     chunk); // [ CS, S_v, 1, H_v * n_seqs]
238
0
        ggml_tensor * ch_kq      = get_slice_2d(ctx0, kq,      chunk); // [ CS,  CS, 1, H_k * n_seqs]
239
0
        ggml_tensor * ch_q_g_exp = get_slice_2d(ctx0, q_g_exp, chunk); // [S_k,  CS, 1, H_k * n_seqs]
240
0
        ggml_tensor * ch_kg_t    = get_slice_2d(ctx0, kg_t,    chunk); // [ CS, S_k, 1, H_v * n_seqs]
241
242
        // [CS, S_v, 1, H_v * n_seqs]
243
0
        ggml_tensor * v_t_p = ggml_mul_mat(ctx0, ch_k_cd, s);
244
0
        cb(v_t_p, "v_prime", il);
245
246
        // [CS, S_v, 1, H_v * n_seqs]
247
0
        ggml_tensor * v_t_new = ggml_sub(ctx0, ch_v_t, v_t_p);
248
0
        cb(v_t_new, "v_t_new", il);
249
250
        // [S_v, CS, 1, H_v * n_seqs]
251
0
        ggml_tensor * v_attn = ggml_mul_mat(ctx0, v_t_new, ch_kq);
252
0
        cb(v_attn, "v_attn", il);
253
254
        // [S_v, CS, 1, H_v * n_seqs]
255
0
        ggml_tensor * attn_inter = ggml_mul_mat(ctx0, s, ch_q_g_exp);
256
0
        cb(attn_inter, "attn_inter", il);
257
258
        // [S_v, CS, 1, H_v * n_seqs]
259
0
        ggml_tensor * o_ch = ggml_add(ctx0, attn_inter, v_attn);
260
0
        cb(o_ch, "dnet_add_ch_attn_out", il);
261
262
0
        v = ggml_set_inplace(ctx0, v, o_ch, v->nb[1], v->nb[2], v->nb[3], chunk * v->nb[2]);
263
264
        // kgdmulvnew = (key_gdiff).transpose(-1, -2) @ v_new
265
        // TODO: head broadcast might not work here - probably will need a transpose
266
0
        ggml_tensor * kgv = ggml_mul_mat(ctx0, ch_kg_t, v_t_new); // [S_k, S_v, 1, H_k * n_seqs]
267
268
        // last_recurrent_state = last_recurrent_state * g_last + kgdmulvnew
269
0
        ggml_tensor * ch_g_last_exp_t = get_slice_2d(ctx0, g_last_exp_t, chunk);
270
271
0
        s = ggml_mul(ctx0, s, ch_g_last_exp_t);
272
0
        s = ggml_add(ctx0, s, kgv);
273
0
        cb(s, "dnet_add_ch_state", il);
274
0
    }
275
276
    // truncate padded tokens
277
0
    ggml_tensor * o = ggml_view_4d(ctx0, v,
278
0
            S_v, n_tokens, H_v, n_seqs,
279
0
            ggml_row_size(v->type, S_v),
280
0
            ggml_row_size(v->type, S_v * CS * n_chunks),
281
0
            ggml_row_size(v->type, S_v * CS * n_chunks * H_v), 0);
282
0
    o = ggml_permute  (ctx0, o, 0, 2, 1, 3); // [S_v, H_v, n_tokens, n_seqs]
283
0
    s = ggml_reshape_4d(ctx0, s, S_v, S_v, H_v, n_seqs);
284
0
    cb(s, "output_state", il);
285
286
0
    return {o, s};
287
0
}
288
289
std::pair<ggml_tensor *, ggml_tensor *> llm_build_delta_net_base::build_delta_net_autoregressive(
290
        ggml_tensor * q,
291
        ggml_tensor * k,
292
        ggml_tensor * v,
293
        ggml_tensor * g,
294
        ggml_tensor * b, // beta
295
        ggml_tensor * s, // state
296
0
        int           il) {
297
0
    const int64_t S_k      = q->ne[0];
298
0
    const int64_t H_k      = q->ne[1];
299
0
    const int64_t n_tokens = q->ne[2];
300
0
    const int64_t n_seqs   = q->ne[3];
301
302
0
    const int64_t S_v = v->ne[0];
303
0
    const int64_t H_v = v->ne[1];
304
305
0
    GGML_ASSERT(n_tokens == 1);
306
307
0
    GGML_ASSERT(S_k == S_v);
308
0
    GGML_ASSERT(H_v % H_k == 0);
309
310
0
    GGML_ASSERT(q->ne[0] == S_k && q->ne[1] == H_k && q->ne[2] == n_tokens && q->ne[3] == n_seqs);
311
0
    GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == H_k && k->ne[2] == n_tokens && k->ne[3] == n_seqs);
312
0
    GGML_ASSERT(v->ne[0] == S_v && v->ne[1] == H_v && v->ne[2] == n_tokens && v->ne[3] == n_seqs);
313
314
0
    GGML_ASSERT(g->ne[0] == 1   || g->ne[0] == S_v);
315
0
    GGML_ASSERT(                   g->ne[1] == H_v && g->ne[2] == n_tokens && g->ne[3] == n_seqs);
316
0
    GGML_ASSERT(b->ne[0] == 1   && b->ne[1] == H_v && b->ne[2] == n_tokens && b->ne[3] == n_seqs);
317
0
    GGML_ASSERT(s->ne[0] == S_v && s->ne[1] == S_v && s->ne[2] == H_v      && s->ne[3] == n_seqs);
318
319
0
    const float scale = 1.0f / sqrtf(S_k);
320
321
0
    q = ggml_scale(ctx0, q, scale);
322
323
0
    q = ggml_permute(ctx0, q, 0, 2, 1, 3); // [S_k, n_tokens, H_k, n_seqs]
324
0
    k = ggml_permute(ctx0, k, 0, 2, 1, 3); // [S_k, n_tokens, H_k, n_seqs]
325
0
    v = ggml_permute(ctx0, v, 0, 2, 1, 3); // [S_v, n_tokens, H_v, n_seqs]
326
327
0
    cb(q, "q_in", il);
328
0
    cb(k, "k_in", il);
329
0
    cb(v, "v_in", il);
330
0
    cb(b, "b_in", il);
331
0
    cb(g, "g_in", il);
332
333
    // GDA: [1,  1,  H_v, n_seqs]
334
    // KDA: [1, S_k, H_v, n_seqs]
335
0
    g = ggml_reshape_4d(ctx0, g, 1, g->ne[0], H_v, n_seqs);
336
0
    b = ggml_reshape_4d(ctx0, b, 1,        1, H_v, n_seqs);
337
338
    // [S_v, S_v, H_v, n_seqs]
339
0
    g = ggml_exp(ctx0, g);
340
0
    s = ggml_mul(ctx0, s, g);
341
342
    // [1, S_v, H_v, n_seqs]
343
0
    ggml_tensor * sk;
344
0
    sk = ggml_mul     (ctx0, s, k);
345
0
    sk = ggml_sum_rows(ctx0, sk);
346
347
    // [S_v, 1, H_v, n_seqs]
348
0
    ggml_tensor * d;
349
0
    d = ggml_sub(ctx0, v, ggml_transpose(ctx0, sk));
350
0
    d = ggml_mul(ctx0, d, b);
351
352
    // [1, S_v, H_v, n_seqs]
353
0
    ggml_tensor * d_t;
354
0
    d_t = ggml_transpose(ctx0, d);
355
356
    // [S_v, S_v, H_v, n_seqs]
357
0
    ggml_tensor * kd;
358
0
    k  = ggml_repeat(ctx0, k, s);
359
0
    kd = ggml_mul   (ctx0, k, d_t);
360
361
0
    s = ggml_add(ctx0, s, kd);
362
363
0
    cb(s, "dnet_add_ar_state", il);
364
365
0
    ggml_tensor * s_q = ggml_mul     (ctx0, s, q);
366
0
    ggml_tensor * o   = ggml_sum_rows(ctx0, s_q);
367
368
0
    o = ggml_permute  (ctx0, o, 2, 0, 1, 3); // [S_v, H_v, n_tokens, n_seqs]
369
370
0
    return {o, s};
371
0
}
372
373
std::pair<ggml_tensor *, ggml_tensor *> llm_build_delta_net_base::build_delta_net_fused(
374
        ggml_tensor * q,
375
        ggml_tensor * k,
376
        ggml_tensor * v,
377
        ggml_tensor * g,
378
        ggml_tensor * b,
379
        ggml_tensor * s,
380
0
        int           il) {
381
0
    const int64_t S_k      = q->ne[0];
382
0
    const int64_t H_k      = q->ne[1];
383
0
    const int64_t n_tokens = q->ne[2];
384
0
    const int64_t n_seqs   = q->ne[3];
385
386
0
    const int64_t S_v = v->ne[0];
387
0
    const int64_t H_v = v->ne[1];
388
389
0
    GGML_ASSERT(S_k == S_v);
390
0
    GGML_ASSERT(H_v % H_k == 0);
391
392
0
    GGML_ASSERT(q->ne[0] == S_k && q->ne[1] == H_k && q->ne[2] == n_tokens && q->ne[3] == n_seqs);
393
0
    GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == H_k && k->ne[2] == n_tokens && k->ne[3] == n_seqs);
394
0
    GGML_ASSERT(v->ne[0] == S_v && v->ne[1] == H_v && v->ne[2] == n_tokens && v->ne[3] == n_seqs);
395
396
0
    GGML_ASSERT(g->ne[0] == 1   || g->ne[0] == S_v);
397
0
    GGML_ASSERT(                   g->ne[1] == H_v && g->ne[2] == n_tokens && g->ne[3] == n_seqs);
398
0
    GGML_ASSERT(b->ne[0] == 1   && b->ne[1] == H_v && b->ne[2] == n_tokens && b->ne[3] == n_seqs);
399
0
    GGML_ASSERT(s->ne[0] == S_v && s->ne[1] == S_v && s->ne[2] == H_v      && s->ne[3] == n_seqs);
400
401
    // K=1: output carries the final state only. state s is 4D [S_v, S_v, H_v, n_seqs].
402
0
    ggml_tensor * result = ggml_gated_delta_net(ctx0, q, k, v, g, b, s, /*K=*/1);
403
0
    if (n_tokens == 1) {
404
0
        cb(result, LLAMA_TENSOR_NAME_FGDN_AR, il);
405
0
    } else {
406
0
        cb(result, LLAMA_TENSOR_NAME_FGDN_CH, il);
407
0
    }
408
409
0
    ggml_tensor * output = ggml_view_4d(ctx0, result,
410
0
            S_v, H_v, n_tokens, n_seqs,
411
0
            ggml_row_size(result->type, S_v),
412
0
            ggml_row_size(result->type, S_v * H_v),
413
0
            ggml_row_size(result->type, S_v * H_v * n_tokens), 0);
414
415
0
    ggml_tensor * new_state = ggml_view_4d(ctx0, result,
416
0
            S_v, S_v, H_v, n_seqs,
417
0
            ggml_row_size(result->type, S_v),
418
0
            ggml_row_size(result->type, S_v * S_v),
419
0
            ggml_row_size(result->type, S_v * S_v * H_v),
420
0
            ggml_row_size(result->type, S_v * H_v * n_tokens * n_seqs));
421
422
0
    return {output, new_state};
423
0
}
424
425
std::pair<ggml_tensor *, ggml_tensor *> llm_build_delta_net_base::build_delta_net(
426
        ggml_tensor * q,
427
        ggml_tensor * k,
428
        ggml_tensor * v,
429
        ggml_tensor * g,
430
        ggml_tensor * b,
431
        ggml_tensor * s,
432
0
        int           il) {
433
0
    const int64_t n_seq_tokens = q->ne[2];
434
435
0
    if (n_seq_tokens == 1) {
436
0
        if (cparams.fused_gdn_ar) {
437
0
            return build_delta_net_fused(q, k, v, g, b, s, il);
438
0
        }
439
0
        return build_delta_net_autoregressive(q, k, v, g, b, s, il);
440
0
    }
441
442
0
    if (cparams.fused_gdn_ch) {
443
0
        return build_delta_net_fused(q, k, v, g, b, s, il);
444
0
    }
445
446
0
    return build_delta_net_chunking(q, k, v, g, b, s, il);
447
0
}
448
449
// Builds the convolution input for the current ubatch and schedules the write-back of the conv
// state(s) into `conv_states_all`.
//
// The stored conv states are loaded with build_rs, reshaped to
// [conv_kernel_size - 1, conv_channels, n_seqs] and concatenated along dim 0 with the transposed
// `qkv_mixed`. Windows of width (conv_kernel_size - 1) taken from that concatenation are then
// copied back into the cache: one window when cparams.n_rs_seq == 0, otherwise n_rs_seq + 1
// windows, each written to its own slot.
//
// Returns the concatenated conv input.
ggml_tensor * llm_build_delta_net_base::build_conv_state(
        llm_graph_input_rs * inp,
        ggml_tensor *        conv_states_all,
        ggml_tensor *        qkv_mixed,
        int64_t              conv_kernel_size,
        int64_t              conv_channels,
        int                  il) {
    const auto * mctx_cur = inp->mctx;

    const auto kv_head  = mctx_cur->get_head();
    const auto mem_size = mctx_cur->get_size();

    const int64_t n_seqs = ubatch.n_seqs;

    ggml_tensor * conv_states = build_rs(inp, conv_states_all, hparams.n_embd_r(), n_seqs);
    cb(conv_states, "conv_states", il);

    conv_states = ggml_reshape_3d(ctx0, conv_states, conv_kernel_size - 1, conv_channels, n_seqs);
    cb(conv_states, "conv_states_reshaped", il);

    qkv_mixed = ggml_transpose(ctx0, qkv_mixed);
    cb(qkv_mixed, "qkv_mixed_transposed", il);

    ggml_tensor * conv_input = ggml_concat(ctx0, conv_states, qkv_mixed, 0);
    cb(conv_input, "conv_input", il);

    // element count and byte size of one stored conv state
    const int64_t row_count = (conv_kernel_size - 1) * conv_channels;

    const size_t row_size = ggml_row_size(conv_states_all->type, row_count);

    // Schedules a copy of the window of conv_input that starts at column `s_idx` into slot `s_slot`
    // of conv_states_all. `named` controls whether the two views are passed to the callback.
    const auto store_window = [&](const int64_t s_idx, const int64_t s_slot, const bool named) {
        ggml_tensor * window =
            ggml_view_3d(ctx0, conv_input,
                    conv_kernel_size - 1, conv_channels, n_seqs,
                    conv_input->nb[1], conv_input->nb[2],
                    ggml_row_size(conv_input->type, s_idx));
        if (named) {
            cb(window, "conv_state_last", il);
        }

        ggml_tensor * target =
            ggml_view_2d(ctx0, conv_states_all,
                    row_count, n_seqs, conv_states_all->nb[1],
                    (s_slot * mem_size + kv_head) * row_size);
        if (named) {
            cb(target, "conv_state_update", il);
        }

        ggml_build_forward_expand(gf, ggml_cpy(ctx0, window, target));
    };

    if (cparams.n_rs_seq == 0) {
        // single state: keep the trailing window of the conv input in slot 0
        store_window(conv_input->ne[0] - conv_states->ne[0], 0, true);
    } else {
        // [TAG_RECURRENT_ROLLBACK_SPLITS]
        // TODO: this logic incorrectly assumes that the last (n_rs_seq + 1) tokens of a sequence in a batch are
        //       inside the same ubatch. currently with `split_equal()` this is not correct

        const int64_t K = (int64_t) cparams.n_rs_seq + 1;

        // t == K is the trailing window and goes to slot 0; earlier windows go to higher slots
        for (int64_t t = 1; t <= K; ++t) {
            // the window start is clamped at 0
            const int64_t s_idx = std::max<int64_t>(0, conv_input->ne[0] - conv_states->ne[0] - K + t);

            store_window(s_idx, K - t, false);
        }
    }

    return conv_input;
}
526
527
// Runs the delta-net attention for the current ubatch and schedules the write-back of the
// recurrent state into `ssm_states_all`.
//
// Two paths:
//   - cparams.n_rs_seq == 0: build_delta_net picks the implementation; the single resulting state
//     is copied into the cache at the current head.
//   - cparams.n_rs_seq  > 0: ggml_gated_delta_net is called directly with K = n_rs_seq + 1, and
//     min(n_seq_tokens, K) state snapshots from its result are copied into the cache.
//
// Returns the attention output: [S_v, H_v, n_seq_tokens, n_seqs] on the snapshot path, and whatever
// build_delta_net returns as its first element otherwise.
ggml_tensor * llm_build_delta_net_base::build_recurrent_attn(
        llm_graph_input_rs * inp,
        ggml_tensor *        ssm_states_all,
        ggml_tensor *        q,
        ggml_tensor *        k,
        ggml_tensor *        v,
        ggml_tensor *        g,
        ggml_tensor *        b,
        ggml_tensor *        s,
        int                  il) {
    const auto * mctx_cur   = inp->mctx;
    const auto   kv_head    = mctx_cur->get_head();
    const uint32_t mem_size = mctx_cur->get_size();

    const int64_t S_v          = s->ne[0];
    const int64_t H_v          = s->ne[2];
    const int64_t n_seqs       = s->ne[3];
    const int64_t n_seq_tokens = q->ne[2];

    // whether more than one state snapshot is kept
    const bool keep_snapshots = cparams.n_rs_seq > 0;

    if (!keep_snapshots) {
        const auto res = build_delta_net(q, k, v, g, b, s, il);

        ggml_tensor * output    = res.first;
        ggml_tensor * new_state = res.second;
        cb(output, "attn_output", il);
        cb(new_state, "new_state", il);

        // destination: n_seqs consecutive state rows starting at the current head
        ggml_tensor * state_dst = ggml_view_2d(ctx0, ssm_states_all,
                hparams.n_embd_s(), n_seqs, ssm_states_all->nb[1],
                kv_head * hparams.n_embd_s() * ggml_element_size(ssm_states_all));

        ggml_build_forward_expand(gf, ggml_cpy(ctx0, new_state, state_dst));

        return output;
    }

    const int64_t D = S_v * S_v * H_v;
    const int64_t K = cparams.n_rs_seq + 1;

    // state s is 4D [S_v, S_v, H_v, n_seqs]; K snapshot slots are written into the output.
    ggml_tensor * gdn_out = ggml_gated_delta_net(ctx0, q, k, v, g, b, s, K);

    // the callback name distinguishes the multi-token case from the single-token case
    cb(gdn_out, n_seq_tokens > 1 ? LLAMA_TENSOR_NAME_FGDN_CH : LLAMA_TENSOR_NAME_FGDN_AR, il);

    // element counts of the output region and of one snapshot (all sequences) in gdn_out
    const int64_t attn_score_elems    = S_v * H_v * n_seq_tokens * n_seqs;
    const int64_t state_size_per_snap = S_v * S_v * H_v * n_seqs;

    ggml_tensor * output = ggml_view_4d(ctx0, gdn_out,
            S_v, H_v, n_seq_tokens, n_seqs,
            ggml_row_size(gdn_out->type, S_v),
            ggml_row_size(gdn_out->type, S_v * H_v),
            ggml_row_size(gdn_out->type, S_v * H_v * n_seq_tokens),
            0);
    cb(output, "attn_output", il);

    // byte size of one state row in the cache
    const size_t row_size = hparams.n_embd_s() * ggml_element_size(ssm_states_all);

    // op writes the last min(n_seq_tokens, K) snapshots; trailing slots are left unwritten
    const int64_t n_written = std::min<int64_t>(n_seq_tokens, K);

    // write the produced snapshots into the recurrent cache (snapshot slot i -> rollback group i)
    // source: snapshots stored after the output elements in gdn_out
    ggml_tensor * snap_src = ggml_view_3d(ctx0, gdn_out,
            D, n_seqs, n_written,
            ggml_row_size(gdn_out->type, D),
            ggml_row_size(gdn_out->type, state_size_per_snap),
            ggml_row_size(gdn_out->type, attn_score_elems));

    // destination: consecutive groups are mem_size rows apart, starting at the current head
    ggml_tensor * snap_dst = ggml_view_3d(ctx0, ssm_states_all,
            D, n_seqs, n_written,
            ssm_states_all->nb[1],
            (size_t) mem_size * row_size,
            (size_t) kv_head * row_size);

    ggml_build_forward_expand(gf, ggml_cpy(ctx0, snap_src, snap_dst));

    return output;
}