Coverage Report

Created: 2026-03-07 06:35

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
/src/llama.cpp/src/models/delta-net-base.cpp
Line
Count
Source
1
#include "models.h"
2
3
// utility to get one slice from the third dimension
4
// input dim:  [x, y, c, b]
5
// output dim: [x, y, 1, b]
6
0
static ggml_tensor * get_slice_2d(ggml_context * ctx0, ggml_tensor * t, int64_t c) {
7
0
    return ggml_view_4d(ctx0, t, t->ne[0], t->ne[1], 1, t->ne[3],
8
0
        t->nb[1], t->nb[2], t->nb[3], t->nb[2] * c);
9
0
}
10
11
0
llm_build_delta_net_base::llm_build_delta_net_base(const llm_graph_params & params) : llm_graph_context(params) {}
12
13
std::pair<ggml_tensor *, ggml_tensor *> llm_build_delta_net_base::build_delta_net_chunking(
14
        ggml_tensor * q,
15
        ggml_tensor * k,
16
        ggml_tensor * v,
17
        ggml_tensor * g,
18
        ggml_tensor * b,
19
        ggml_tensor * s,
20
0
        int           il) {
21
0
    const int64_t S_k      = q->ne[0];
22
0
    const int64_t H_k      = q->ne[1];
23
0
    const int64_t n_tokens = q->ne[2];
24
0
    const int64_t n_seqs   = q->ne[3];
25
26
0
    const int64_t S_v = v->ne[0];
27
0
    const int64_t H_v = v->ne[1];
28
0
    const bool kda = (g->ne[0] == S_k && g->ne[1] == H_k);
29
30
0
    GGML_ASSERT(S_k == S_v);
31
0
    GGML_ASSERT(H_v % H_k == 0);
32
33
0
    GGML_ASSERT(q->ne[0] == S_k && q->ne[1] == H_k && q->ne[2] == n_tokens && q->ne[3] == n_seqs);
34
0
    GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == H_k && k->ne[2] == n_tokens && k->ne[3] == n_seqs);
35
0
    GGML_ASSERT(v->ne[0] == S_v && v->ne[1] == H_v && v->ne[2] == n_tokens && v->ne[3] == n_seqs);
36
37
0
    GGML_ASSERT(g->ne[0] == 1   || g->ne[0] == S_v);
38
0
    GGML_ASSERT(                   g->ne[1] == H_v && g->ne[2] == n_tokens && g->ne[3] == n_seqs);
39
0
    GGML_ASSERT(b->ne[0] == 1   && b->ne[1] == H_v && b->ne[2] == n_tokens && b->ne[3] == n_seqs);
40
0
    GGML_ASSERT(s->ne[0] == S_v && s->ne[1] == S_v && s->ne[2] == H_v      && s->ne[3] == n_seqs);
41
42
0
    const float scale = 1.0f / sqrtf(S_k);
43
44
0
    q = ggml_scale(ctx0, q, scale);
45
46
0
    cb(q, "q_in", il);
47
0
    cb(k, "k_in", il);
48
0
    cb(v, "v_in", il);
49
0
    cb(b, "b_in", il);
50
0
    cb(g, "g_in", il);
51
52
0
    q = ggml_permute(ctx0, q, 0, 2, 1, 3); // [S_k, n_tokens, H_k, n_seqs]
53
0
    k = ggml_permute(ctx0, k, 0, 2, 1, 3); // [S_k, n_tokens, H_k, n_seqs]
54
0
    v = ggml_permute(ctx0, v, 0, 2, 1, 3); // [S_v, n_tokens, H_v, n_seqs]
55
0
    g = ggml_permute(ctx0, g, 0, 2, 1, 3); // [g_0, n_tokens, H_v, n_seqs]
56
0
    b = ggml_permute(ctx0, b, 0, 2, 1, 3); // [  1, n_tokens, H_v, n_seqs]
57
58
0
    const int CS = kda ? 16 : 64; // chunk size
59
60
0
    const int pad = (CS - n_tokens % CS) % CS;
61
0
    const int n_chunks = (n_tokens + pad) / CS;
62
63
0
    q = ggml_pad(ctx0, q, 0, pad, 0, 0);
64
0
    k = ggml_pad(ctx0, k, 0, pad, 0, 0);
65
0
    v = ggml_pad(ctx0, v, 0, pad, 0, 0);
66
0
    g = ggml_pad(ctx0, g, 0, pad, 0, 0);
67
0
    b = ggml_pad(ctx0, b, 0, pad, 0, 0);
68
69
0
    ggml_tensor * v_b = ggml_mul(ctx0, v, b);
70
0
    ggml_tensor * k_b = ggml_mul(ctx0, k, b);
71
72
0
    cb(v_b, "v_b", il);
73
0
    cb(k_b, "k_b", il);
74
75
0
    q   = ggml_reshape_4d(ctx0, q,   S_k, CS, n_chunks, H_k * n_seqs);
76
0
    k   = ggml_reshape_4d(ctx0, k,   S_k, CS, n_chunks, H_k * n_seqs);
77
0
    k_b = ggml_reshape_4d(ctx0, k_b, S_k, CS, n_chunks, H_v * n_seqs);
78
0
    v   = ggml_reshape_4d(ctx0, v,   S_v, CS, n_chunks, H_v * n_seqs);
79
0
    v_b = ggml_reshape_4d(ctx0, v_b, S_v, CS, n_chunks, H_v * n_seqs);
80
81
0
    g = ggml_reshape_4d(ctx0, g, g->ne[0], CS, n_chunks, H_v * n_seqs);
82
0
    b = ggml_reshape_4d(ctx0, b, 1,        CS, n_chunks, H_v * n_seqs);
83
84
    // [CS, g_0, n_chunks, H_v * n_seqs]
85
    // TODO: extend ggml_cumsum with axis parameter to avoid transpose
86
0
    ggml_tensor * g_cs = ggml_cumsum(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, g)));
87
0
    cb(g_cs, "g_cs", il);
88
89
0
    ggml_tensor * kb = nullptr;
90
0
    ggml_tensor * kq = nullptr;
91
0
    if (kda) {
92
0
        const int64_t CHB = n_chunks * H_k * n_seqs;
93
94
0
        ggml_tensor * g_cs_i = ggml_reshape_4d(ctx0, g_cs, CS, 1, S_k, CHB);  // [chunk_size, 1, S_k, CHB]
95
0
        ggml_tensor * g_cs_j = ggml_reshape_4d(ctx0, g_cs, 1, CS, S_k, CHB);  // [1, chunk_size, S_k, CHB]
96
97
0
        g_cs_j = ggml_repeat_4d(ctx0, g_cs_j, CS, CS, S_k, CHB);  // [1, chunk_size, S_k, CHB] -> [chunk_size, chunk_size, S_k, CHB]
98
99
        // decay_mask [chunk_size,chunk_size,S_k,CHB]
100
0
        ggml_tensor * decay_mask;
101
0
        decay_mask = ggml_sub(ctx0, g_cs_j, g_cs_i);
102
0
        decay_mask = ggml_tri(ctx0, decay_mask, GGML_TRI_TYPE_LOWER_DIAG);
103
0
        decay_mask = ggml_exp(ctx0, decay_mask);
104
0
        cb(decay_mask, "decay_mask", il);
105
106
        // decay_mask [S_k,BT_j,BT_i,CHB] *Note* second and third chunk_sizes are switched
107
0
        decay_mask = ggml_cont_4d(ctx0, ggml_permute(ctx0, decay_mask, 2, 1, 0, 3), S_k, CS, CS, CHB);
108
109
0
        ggml_tensor * k_b_i = ggml_reshape_4d(ctx0, k_b, S_k, CS,  1, CHB);
110
0
        ggml_tensor * k_j   = ggml_reshape_4d(ctx0, k,   S_k,  1, CS, CHB);
111
0
        ggml_tensor * q_i   = ggml_reshape_4d(ctx0, q,   S_k, CS,  1, CHB);
112
113
0
        ggml_tensor * decay_k_b_i = ggml_mul(ctx0, decay_mask, k_b_i);
114
0
        ggml_tensor * decay_q_i   = ggml_mul(ctx0, decay_mask, q_i);
115
116
        // decay_k_b_i [S,BT,BT,CHB] @ k_j [S,1,BT,CHB] = Akk [BT,1,BT,CHB]
117
0
        kb = ggml_mul_mat(ctx0, decay_k_b_i, k_j);
118
0
        kq = ggml_mul_mat(ctx0, decay_q_i,   k_j);
119
120
0
        kb = ggml_cont(ctx0, ggml_transpose(ctx0, ggml_reshape_4d(ctx0, kb, CS, CS, n_chunks, H_v * n_seqs)));
121
0
        kq = ggml_cont(ctx0, ggml_transpose(ctx0, ggml_reshape_4d(ctx0, kq, CS, CS, n_chunks, H_v * n_seqs)));
122
0
    } else {
123
0
        ggml_tensor * g_cs_i = g_cs;
124
0
        ggml_tensor * g_cs_j = ggml_reshape_4d(ctx0, g_cs, 1, CS, n_chunks, H_v * n_seqs);
125
126
0
        g_cs_j = ggml_repeat_4d(ctx0, g_cs_j, CS, CS, n_chunks, H_v * n_seqs);
127
128
        // [CS, CS, n_chunks, H_v * n_seqs]
129
0
        ggml_tensor * decay_mask;
130
0
        decay_mask = ggml_sub(ctx0, g_cs_j, g_cs_i);
131
0
        decay_mask = ggml_tri(ctx0, decay_mask, GGML_TRI_TYPE_LOWER_DIAG);
132
0
        decay_mask = ggml_exp(ctx0, decay_mask);
133
0
        cb(decay_mask, "decay_mask", il);
134
135
        // [CS, CS, n_chunks, H_k * n_seqs]
136
0
        kb = ggml_mul_mat(ctx0, k,  k_b);
137
0
        kb = ggml_mul    (ctx0, kb, decay_mask);
138
139
        // [CS, CS, n_chunks, H_k * n_seqs]
140
0
        kq = ggml_mul_mat(ctx0, k, q);
141
0
        kq = ggml_mul(ctx0, kq, decay_mask);
142
0
    }
143
144
0
    kq = ggml_tri(ctx0, kq, GGML_TRI_TYPE_LOWER_DIAG);
145
0
    cb(kq, "kq", il);
146
147
    // [CS, CS, n_chunks, H_k * n_seqs]
148
0
    ggml_tensor * attn;
149
0
    attn = ggml_tri(ctx0, kb, GGML_TRI_TYPE_LOWER);
150
0
    cb(attn, "attn", il);
151
152
0
    ggml_tensor * identity;
153
0
    identity = ggml_view_1d(ctx0, attn, CS, 0);
154
0
    identity = ggml_fill   (ctx0, identity, 1.0f);
155
0
    identity = ggml_diag   (ctx0, identity);
156
157
0
    ggml_tensor * lhs = ggml_add(ctx0, attn, identity);
158
0
    cb(lhs, "dnet_add_ch_lhs", il);
159
160
0
    attn = ggml_neg(ctx0, attn);
161
0
    cb(attn, "attn_pre_solve", il);
162
163
0
    ggml_tensor * lin_solve = ggml_solve_tri(ctx0, lhs, attn, true, true, false);
164
0
    attn = ggml_add(ctx0, lin_solve, identity);
165
0
    cb(attn, "dnet_add_ch_attn_solved", il); // [CS, CS, n_chunks, H_k * n_seqs]
166
167
    // [S_v, CS, n_chunks, H_v * n_seqs]
168
0
    v = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, v_b)), attn);
169
170
    // [CS, 1, n_chunks, H_v * n_seqs] KDA: [CS, S_k, n_chunks, H_v * n_seqs]
171
0
    ggml_tensor * g_exp = ggml_exp(ctx0, g_cs);
172
173
0
    k_b = ggml_cont(ctx0, ggml_transpose(ctx0, k_b));
174
175
    // [CS, S_k, n_chunks, H_k * n_seqs]
176
0
    ggml_tensor * kbg = ggml_mul(ctx0, k_b, g_exp);
177
0
    cb(kbg, "k_beta_g_exp", il);
178
179
    // [S_k, CS, n_chunks, H_k * n_seqs]
180
0
    ggml_tensor * k_cd = ggml_mul_mat(ctx0, kbg, attn);
181
0
    cb(k_cd, "k_cumdecay", il);
182
183
    // [1, CS, n_chunks, H_k * n_seqs] KDA: [S_k, CS, n_chunks, H_k * n_seqs]
184
0
    ggml_tensor * g_exp_t = ggml_cont(ctx0, ggml_transpose(ctx0, g_exp));
185
0
    ggml_tensor * q_g_exp = ggml_mul(ctx0, q, g_exp_t);
186
187
    // vectorized calculation of key_gdiff
188
    // improved from the chunked version:
189
    //   g_last = torch.clamp(g_cum[:, :, -1], max=50.0).exp().unsqueeze(-1).unsqueeze(-1)
190
    //   g_diff = torch.clamp(g_cum[:, :, -1:] - g_cum, max=50.0).exp()
191
    //   key_gdiff = key * g_diff.unsqueeze(-1)
192
    //   kgdmulvnew = (key_gdiff).transpose(-1, -2) @ v_new
193
    //   last_recurrent_state = last_recurrent_state * g_last + kgdmulvnew
194
195
    // get last element in g_cumsum along CS dimension (ne0)
196
    // example: [[x, y, z, ..., last], ...] -> [[last], ...]
197
    // [1, 1, n_chunks, H_v * n_seqs] KDA: [1, S_k, n_chunks, H_v * n_seqs]
198
0
    ggml_tensor * g_last = ggml_view_4d(ctx0, g_cs, 1, g_cs->ne[1], g_cs->ne[2], g_cs->ne[3],
199
0
            g_cs->nb[1],
200
0
            g_cs->nb[2],
201
0
            g_cs->nb[3],
202
0
            ggml_row_size(g_cs->type, g_cs->ne[0] - 1));
203
0
    cb(g_last, "g_last", il);
204
205
    // TODO: remove this cont when CUDA supports non-cont unary ops
206
0
    g_last = ggml_cont(ctx0, g_last);
207
208
    // [1, 1, n_chunks, H_v * n_seqs] KDA: [S_k, 1, n_chunks, H_v * n_seqs]
209
0
    ggml_tensor * g_last_exp_t = ggml_transpose(ctx0, ggml_exp(ctx0, g_last));
210
0
    cb(g_last_exp_t, "g_last_exp_t", il);
211
212
    // [CS, 1, n_chunks, H_v * n_seqs] KDA: [CS, S_k, n_chunks, H_v * n_seqs]
213
0
    ggml_tensor * g_diff = ggml_neg(ctx0, ggml_sub(ctx0, g_cs, g_last));
214
0
    cb(g_diff, "g_diff", il);
215
216
0
    ggml_tensor * g_diff_exp_t = ggml_cont(ctx0, ggml_transpose(ctx0, ggml_exp(ctx0, g_diff)));
217
218
    // [S_k, CS, n_chunks, H_v * n_seqs]
219
0
    ggml_tensor * kg = ggml_mul(ctx0, k, g_diff_exp_t);
220
0
    cb(kg, "key_gdiff", il);
221
222
    // [CS, S_k, n_chunks, H_v * n_seqs]
223
0
    ggml_tensor * kg_t = ggml_cont(ctx0, ggml_transpose(ctx0, kg));
224
0
    cb(kg_t, "key_gdiff_t", il);
225
226
0
    ggml_tensor * s_t = ggml_transpose(ctx0, s);
227
0
    s_t = ggml_cont_4d(ctx0, s_t, S_v, S_v, 1, H_v * n_seqs);
228
0
    cb(s_t, "dnet_add_ch_state", il);
229
230
    // [CS, S_v, n_chunks, H_v * n_seqs]
231
0
    ggml_tensor * v_t = ggml_cont(ctx0, ggml_transpose(ctx0, v));
232
233
0
    for (int64_t chunk = 0; chunk < n_chunks; chunk++) {
234
0
        ggml_tensor * ch_k_cd    = get_slice_2d(ctx0, k_cd,    chunk); // [S_k,  CS, 1, H_k * n_seqs]
235
0
        ggml_tensor * ch_v_t     = get_slice_2d(ctx0, v_t,     chunk); // [ CS, S_v, 1, H_v * n_seqs]
236
0
        ggml_tensor * ch_kq      = get_slice_2d(ctx0, kq,      chunk); // [ CS,  CS, 1, H_k * n_seqs]
237
0
        ggml_tensor * ch_q_g_exp = get_slice_2d(ctx0, q_g_exp, chunk); // [S_k,  CS, 1, H_k * n_seqs]
238
0
        ggml_tensor * ch_kg_t    = get_slice_2d(ctx0, kg_t,    chunk); // [ CS, S_k, 1, H_v * n_seqs]
239
240
        // [CS, S_v, 1, H_v * n_seqs]
241
0
        ggml_tensor * v_t_p = ggml_mul_mat(ctx0, ch_k_cd, s_t);
242
0
        cb(v_t_p, "v_prime", il);
243
244
        // [CS, S_v, 1, H_v * n_seqs]
245
0
        ggml_tensor * v_t_new = ggml_sub(ctx0, ch_v_t, v_t_p);
246
0
        cb(v_t_new, "v_t_new", il);
247
248
        // [S_v, CS, 1, H_v * n_seqs]
249
0
        ggml_tensor * v_attn = ggml_mul_mat(ctx0, v_t_new, ch_kq);
250
0
        cb(v_attn, "v_attn", il);
251
252
        // [S_v, CS, 1, H_v * n_seqs]
253
0
        ggml_tensor * attn_inter = ggml_mul_mat(ctx0, s_t, ch_q_g_exp);
254
0
        cb(attn_inter, "attn_inter", il);
255
256
        // [S_v, CS, 1, H_v * n_seqs]
257
0
        ggml_tensor * o_ch = ggml_add(ctx0, attn_inter, v_attn);
258
0
        cb(o_ch, "dnet_add_ch_attn_out", il);
259
260
0
        v = ggml_set_inplace(ctx0, v, o_ch, v->nb[1], v->nb[2], v->nb[3], chunk * v->nb[2]);
261
262
        // kgdmulvnew = (key_gdiff).transpose(-1, -2) @ v_new
263
        // TODO: head broadcast might not work here - probably will need a transpose
264
0
        ggml_tensor * kgv = ggml_mul_mat(ctx0, ch_kg_t, v_t_new); // [S_k, S_v, 1, H_k * n_seqs]
265
266
        // last_recurrent_state = last_recurrent_state * g_last + kgdmulvnew
267
0
        ggml_tensor * ch_g_last_exp_t = get_slice_2d(ctx0, g_last_exp_t, chunk);
268
269
0
        s_t = ggml_mul(ctx0, s_t, ch_g_last_exp_t);
270
0
        s_t = ggml_add(ctx0, s_t, kgv);
271
0
        cb(s_t, "dnet_add_ch_state", il);
272
0
    }
273
274
0
    s_t = ggml_reshape_4d(ctx0, s_t, S_v, S_v, H_v, n_seqs);
275
276
    // truncate padded tokens
277
0
    ggml_tensor * o = ggml_view_4d(ctx0, v,
278
0
            S_v, n_tokens, H_v, n_seqs,
279
0
            ggml_row_size(v->type, S_v),
280
0
            ggml_row_size(v->type, S_v * CS * n_chunks),
281
0
            ggml_row_size(v->type, S_v * CS * n_chunks * H_v), 0);
282
0
    o = ggml_permute  (ctx0, o, 0, 2, 1, 3); // [S_v, H_v, n_tokens, n_seqs]
283
0
    s = ggml_transpose(ctx0, s_t);
284
0
    cb(s, "output_state", il);
285
286
0
    return {o, s};
287
0
}
288
289
std::pair<ggml_tensor *, ggml_tensor *> llm_build_delta_net_base::build_delta_net_autoregressive(
290
        ggml_tensor * q,
291
        ggml_tensor * k,
292
        ggml_tensor * v,
293
        ggml_tensor * g,
294
        ggml_tensor * b, // beta
295
        ggml_tensor * s, // state
296
0
        int           il) {
297
0
    const int64_t S_k      = q->ne[0];
298
0
    const int64_t H_k      = q->ne[1];
299
0
    const int64_t n_tokens = q->ne[2];
300
0
    const int64_t n_seqs   = q->ne[3];
301
302
0
    const int64_t S_v = v->ne[0];
303
0
    const int64_t H_v = v->ne[1];
304
305
0
    GGML_ASSERT(n_tokens == 1);
306
307
0
    GGML_ASSERT(S_k == S_v);
308
0
    GGML_ASSERT(H_v % H_k == 0);
309
310
0
    GGML_ASSERT(q->ne[0] == S_k && q->ne[1] == H_k && q->ne[2] == n_tokens && q->ne[3] == n_seqs);
311
0
    GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == H_k && k->ne[2] == n_tokens && k->ne[3] == n_seqs);
312
0
    GGML_ASSERT(v->ne[0] == S_v && v->ne[1] == H_v && v->ne[2] == n_tokens && v->ne[3] == n_seqs);
313
314
0
    GGML_ASSERT(g->ne[0] == 1   || g->ne[0] == S_v);
315
0
    GGML_ASSERT(                   g->ne[1] == H_v && g->ne[2] == n_tokens && g->ne[3] == n_seqs);
316
0
    GGML_ASSERT(b->ne[0] == 1   && b->ne[1] == H_v && b->ne[2] == n_tokens && b->ne[3] == n_seqs);
317
0
    GGML_ASSERT(s->ne[0] == S_v && s->ne[1] == S_v && s->ne[2] == H_v      && s->ne[3] == n_seqs);
318
319
0
    const float scale = 1.0f / sqrtf(S_k);
320
321
0
    q = ggml_scale(ctx0, q, scale);
322
323
0
    q = ggml_permute(ctx0, q, 0, 2, 1, 3); // [S_k, n_tokens, H_k, n_seqs]
324
0
    k = ggml_permute(ctx0, k, 0, 2, 1, 3); // [S_k, n_tokens, H_k, n_seqs]
325
0
    v = ggml_permute(ctx0, v, 0, 2, 1, 3); // [S_v, n_tokens, H_v, n_seqs]
326
327
0
    cb(q, "q_in", il);
328
0
    cb(k, "k_in", il);
329
0
    cb(v, "v_in", il);
330
0
    cb(b, "b_in", il);
331
0
    cb(g, "g_in", il);
332
333
    // GDA: [1,  1,  H_v, n_seqs]
334
    // KDA: [1, S_k, H_v, n_seqs]
335
0
    g = ggml_reshape_4d(ctx0, g, 1, g->ne[0], H_v, n_seqs);
336
0
    b = ggml_reshape_4d(ctx0, b, 1,        1, H_v, n_seqs);
337
338
    // [S_v, S_v, H_v, n_seqs]
339
0
    g = ggml_exp(ctx0, g);
340
0
    s = ggml_mul(ctx0, s, g);
341
342
0
    ggml_tensor * s_t = ggml_cont(ctx0, ggml_transpose(ctx0, s));
343
344
    // [1, S_v, H_v, n_seqs]
345
0
    ggml_tensor * sk;
346
0
    sk = ggml_mul     (ctx0, s_t, k);
347
0
    sk = ggml_sum_rows(ctx0, sk);
348
349
    // [S_v, 1, H_v, n_seqs]
350
0
    ggml_tensor * d;
351
0
    d = ggml_sub(ctx0, v, ggml_transpose(ctx0, sk));
352
0
    d = ggml_mul(ctx0, d, b);
353
354
    // [1, S_v, H_v, n_seqs]
355
0
    ggml_tensor * d_t;
356
0
    d_t = ggml_transpose(ctx0, d);
357
358
    // [S_v, S_v, H_v, n_seqs]
359
0
    ggml_tensor * kd;
360
0
    k  = ggml_repeat(ctx0, k, s);
361
0
    kd = ggml_mul   (ctx0, k, d_t);
362
363
0
    s_t = ggml_add(ctx0, s_t, kd);
364
365
0
    cb(s_t, "dnet_add_ar_state", il);
366
367
0
    ggml_tensor * s_q = ggml_mul     (ctx0, s_t, q);
368
0
    ggml_tensor * o   = ggml_sum_rows(ctx0, s_q);
369
370
0
    o = ggml_permute  (ctx0, o, 2, 0, 1, 3); // [S_v, H_v, n_tokens, n_seqs]
371
0
    s = ggml_transpose(ctx0, s_t);           // [S_v, S_v, H_v, n_seqs]
372
373
0
    return {o, s};
374
0
}